diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/pom.xml b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/pom.xml index 8deb0db59b9..57341998a8c 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/pom.xml +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/pom.xml @@ -91,6 +91,23 @@ junit-jupiter test + + + + org.springframework.ai + spring-ai-autoconfigure-model-tool + ${project.parent.version} + test + + + + + org.springframework.ai + spring-ai-autoconfigure-model-chat-client + ${project.parent.version} + test + + diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/McpToolsConfigurationTests.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/McpToolsConfigurationTests.java new file mode 100644 index 00000000000..674f2663a5b --- /dev/null +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-client-webflux/src/test/java/org/springframework/ai/mcp/client/webflux/autoconfigure/McpToolsConfigurationTests.java @@ -0,0 +1,268 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp.client.webflux.autoconfigure; + +import java.util.List; + +import io.modelcontextprotocol.client.McpSyncClient; +import io.modelcontextprotocol.spec.McpSchema; +import org.junit.jupiter.api.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springaicommunity.mcp.annotation.McpSampling; + +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.model.ChatModel; +import org.springframework.ai.mcp.SyncMcpToolCallbackProvider; +import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration; +import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration; +import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; +import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientSpecificationFactoryAutoConfiguration; +import org.springframework.ai.model.chat.client.autoconfigure.ChatClientAutoConfiguration; +import org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.ToolCallbackProvider; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.resolution.ToolCallbackResolver; +import org.springframework.ai.util.json.schema.JsonSchemaGenerator; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.annotation.Bean; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.core.ResolvableType; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * @author Daniel Garnier-Moiroux + */ +class McpToolsConfigurationTests { + + /** + * Test that MCP tools have handlers configured when they use a chat client. This + * verifies that there is no cyclic dependency + * {@code McpClient -> @McpHandling -> ChatClient -> McpClient}. + */ + @Test + void mcpClientSupportsSampling() { + //@formatter:off + var clientApplicationContext = new ApplicationContextRunner() + .withUserConfiguration(TestMcpClientHandlers.class) + // Create a transport + .withPropertyValues("spring.ai.mcp.client.streamable-http.connections.server1.url=http://localhost:0", + "spring.ai.mcp.client.initialized=false") + .withConfiguration(AutoConfigurations.of( + // Transport + StreamableHttpWebFluxTransportAutoConfiguration.class, + // MCP clients + McpToolCallbackAutoConfiguration.class, + McpClientAutoConfiguration.class, + McpClientAnnotationScannerAutoConfiguration.class, + McpClientSpecificationFactoryAutoConfiguration.class, + // Tool callbacks + ToolCallingAutoConfiguration.class, + // Chat client for sampling + ChatClientAutoConfiguration.class, + ChatModelAutoConfiguration.class + )); + //@formatter:on + clientApplicationContext.run(ctx -> { + // If the MCP callback provider is picked un by the + // ToolCallingAutoConfiguration, + // #getToolCallbacks will be called during the init phase, and try to call the + // MCP server + // There is no MCP server in this test, so the context would not even start. + String[] clients = ctx + .getBeanNamesForType(ResolvableType.forType(new ParameterizedTypeReference>() { + })); + assertThat(clients).hasSize(1); + List syncClients = (List) ctx.getBean(clients[0]); + assertThat(syncClients).hasSize(1) + .first() + .extracting(McpSyncClient::getClientCapabilities) + .extracting(McpSchema.ClientCapabilities::sampling) + .describedAs("Sampling") + .isNotNull(); + }); + } + + /** + * Ensure that MCP-related {@link ToolCallbackProvider}s do not get their + * {@code getToolCallbacks} method called on startup, and that, when possible, they + * are not injected into the default {@link ToolCallbackResolver}. + */ + @Test + void toolCallbacksRegistered() { + var clientApplicationContext = new ApplicationContextRunner() + .withUserConfiguration(TestToolCallbackConfiguration.class) + .withConfiguration(AutoConfigurations.of(ToolCallingAutoConfiguration.class)); + + clientApplicationContext.run(ctx -> { + // Observable behavior + var resolver = ctx.getBean(ToolCallbackResolver.class); + + // Resolves beans that are NOT mcp-related + assertThat(resolver.resolve("toolCallbackProvider")).isNotNull(); + assertThat(resolver.resolve("customToolCallbackProvider")).isNotNull(); + + // MCP toolcallback providers are never added to the resolver + + // Bean graph setup + var injectedProviders = (List) ctx.getBean( + "org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration.toolcallbackprovider.mcp-excluded"); + // Beans exposed as non-MCP + var toolCallbackProvider = (ToolCallbackProvider) ctx.getBean("toolCallbackProvider"); + var customToolCallbackProvider = (ToolCallbackProvider) ctx.getBean("customToolCallbackProvider"); + // This is injected in the resolver bean, because it's exposed as a + // ToolCallbackProvider, but it's not added to the resolver + var genericMcpToolCallbackProvider = (ToolCallbackProvider) ctx.getBean("genericMcpToolCallbackProvider"); + + // beans exposed as MCP + var mcpToolCallbackProvider = (ToolCallbackProvider) ctx.getBean("mcpToolCallbackProvider"); + var customMcpToolCallbackProvider = (ToolCallbackProvider) ctx.getBean("customMcpToolCallbackProvider"); + + assertThat(injectedProviders) + .containsExactlyInAnyOrder(toolCallbackProvider, customToolCallbackProvider, + genericMcpToolCallbackProvider) + .doesNotContain(mcpToolCallbackProvider, customMcpToolCallbackProvider); + + }); + } + + static class TestMcpClientHandlers { + + private static final Logger logger = LoggerFactory.getLogger(TestMcpClientHandlers.class); + + private final ChatClient chatClient; + + TestMcpClientHandlers(ChatClient.Builder clientBuilder) { + this.chatClient = clientBuilder.build(); + } + + @McpSampling(clients = "server1") + McpSchema.CreateMessageResult samplingHandler(McpSchema.CreateMessageRequest llmRequest) { + logger.info("MCP SAMPLING: {}", llmRequest); + + String userPrompt = ((McpSchema.TextContent) llmRequest.messages().get(0).content()).text(); + String modelHint = llmRequest.modelPreferences().hints().get(0).name(); + // In a real use-case, we would use the chat client to call the LLM again + logger.info("MCP SAMPLING: simulating using chat client {}", this.chatClient); + + return McpSchema.CreateMessageResult.builder() + .content(new McpSchema.TextContent("Response " + userPrompt + " with model hint " + modelHint)) + .build(); + } + + } + + static class ChatModelAutoConfiguration { + + /** + * This is typically provided by a model-specific autoconfig, such as + * {@code AnthropicChatAutoConfiguration}. We do not need a full LLM in this test, + * so we mock out the chat model. + */ + @Bean + ChatModel chatModel() { + return mock(ChatModel.class); + } + + } + + static class TestToolCallbackConfiguration { + + @Bean + ToolCallbackProvider toolCallbackProvider() { + var tcp = mock(ToolCallbackProvider.class); + when(tcp.getToolCallbacks()).thenReturn(toolCallback("toolCallbackProvider")); + return tcp; + } + + // This bean depends on the resolver, to ensure there are no cyclic dependencies + @Bean + SyncMcpToolCallbackProvider mcpToolCallbackProvider(ToolCallbackResolver resolver) { + var tcp = mock(SyncMcpToolCallbackProvider.class); + when(tcp.getToolCallbacks()) + .thenThrow(new RuntimeException("mcpToolCallbackProvider#getToolCallbacks should not be called")); + return tcp; + } + + @Bean + CustomToolCallbackProvider customToolCallbackProvider() { + return new CustomToolCallbackProvider("customToolCallbackProvider"); + } + + // This bean depends on the resolver, to ensure there are no cyclic dependencies + @Bean + CustomMcpToolCallbackProvider customMcpToolCallbackProvider(ToolCallbackResolver resolver) { + return new CustomMcpToolCallbackProvider(); + } + + // This will be added to the resolver, because the visible type of the bean + // is ToolCallbackProvider ; we would need to actually instantiate the bean + // to find out that it is MCP-related + @Bean + ToolCallbackProvider genericMcpToolCallbackProvider() { + return new CustomMcpToolCallbackProvider(); + } + + static ToolCallback[] toolCallback(String name) { + return new ToolCallback[] { new ToolCallback() { + @Override + public ToolDefinition getToolDefinition() { + return ToolDefinition.builder() + .name(name) + .inputSchema(JsonSchemaGenerator.generateForType(String.class)) + .build(); + } + + @Override + public String call(String toolInput) { + return "~~ not implemented ~~"; + } + } }; + } + + static class CustomToolCallbackProvider implements ToolCallbackProvider { + + private final String name; + + CustomToolCallbackProvider(String name) { + this.name = name; + } + + @Override + public ToolCallback[] getToolCallbacks() { + return toolCallback(this.name); + } + + } + + static class CustomMcpToolCallbackProvider extends SyncMcpToolCallbackProvider { + + @Override + public ToolCallback[] getToolCallbacks() { + throw new RuntimeException("CustomMcpToolCallbackProvider#getToolCallbacks should not be called"); + } + + } + + } + +} diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsWithLLMIT.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsWithLLMIT.java new file mode 100644 index 00000000000..596f9cb20c3 --- /dev/null +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-webflux/src/test/java/org/springframework/ai/mcp/server/autoconfigure/StreamableMcpAnnotationsWithLLMIT.java @@ -0,0 +1,339 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.mcp.server.autoconfigure; + +import java.time.Duration; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import io.modelcontextprotocol.server.McpSyncServer; +import io.modelcontextprotocol.server.transport.WebFluxStreamableServerTransportProvider; +import io.modelcontextprotocol.spec.McpSchema; +import io.modelcontextprotocol.spec.McpSchema.CreateMessageResult; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springaicommunity.mcp.annotation.McpElicitation; +import org.springaicommunity.mcp.annotation.McpLogging; +import org.springaicommunity.mcp.annotation.McpProgress; +import org.springaicommunity.mcp.annotation.McpSampling; +import org.springaicommunity.mcp.annotation.McpTool; +import org.springaicommunity.mcp.annotation.McpToolParam; +import org.springaicommunity.mcp.context.McpSyncRequestContext; +import org.springaicommunity.mcp.context.StructuredElicitResult; +import reactor.netty.DisposableServer; +import reactor.netty.http.server.HttpServer; + +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.mcp.client.common.autoconfigure.McpClientAutoConfiguration; +import org.springframework.ai.mcp.client.common.autoconfigure.McpToolCallbackAutoConfiguration; +import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientAnnotationScannerAutoConfiguration; +import org.springframework.ai.mcp.client.common.autoconfigure.annotations.McpClientSpecificationFactoryAutoConfiguration; +import org.springframework.ai.mcp.client.webflux.autoconfigure.StreamableHttpWebFluxTransportAutoConfiguration; +import org.springframework.ai.mcp.server.common.autoconfigure.McpServerAutoConfiguration; +import org.springframework.ai.mcp.server.common.autoconfigure.ToolCallbackConverterAutoConfiguration; +import org.springframework.ai.mcp.server.common.autoconfigure.annotations.McpServerAnnotationScannerAutoConfiguration; +import org.springframework.ai.mcp.server.common.autoconfigure.annotations.McpServerSpecificationFactoryAutoConfiguration; +import org.springframework.ai.mcp.server.common.autoconfigure.properties.McpServerProperties; +import org.springframework.ai.mcp.server.common.autoconfigure.properties.McpServerStreamableHttpProperties; +import org.springframework.ai.model.anthropic.autoconfigure.AnthropicChatAutoConfiguration; +import org.springframework.ai.model.chat.client.autoconfigure.ChatClientAutoConfiguration; +import org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration; +import org.springframework.ai.retry.autoconfigure.SpringAiRetryAutoConfiguration; +import org.springframework.ai.tool.ToolCallbackProvider; +import org.springframework.boot.autoconfigure.AutoConfigurations; +import org.springframework.boot.autoconfigure.web.client.RestClientAutoConfiguration; +import org.springframework.boot.autoconfigure.web.reactive.function.client.WebClientAutoConfiguration; +import org.springframework.boot.test.context.runner.ApplicationContextRunner; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.http.server.reactive.HttpHandler; +import org.springframework.http.server.reactive.ReactorHttpHandlerAdapter; +import org.springframework.test.util.TestSocketUtils; +import org.springframework.web.reactive.function.server.RouterFunction; +import org.springframework.web.reactive.function.server.RouterFunctions; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Christian Tzolov + * @author Daniel Garnier-Moiroux + */ +@EnabledIfEnvironmentVariable(named = "ANTHROPIC_API_KEY", matches = ".*") +public class StreamableMcpAnnotationsWithLLMIT { + + private final ApplicationContextRunner serverContextRunner = new ApplicationContextRunner() + .withPropertyValues("spring.ai.mcp.server.protocol=STREAMABLE") + .withConfiguration(AutoConfigurations.of(McpServerAutoConfiguration.class, + ToolCallbackConverterAutoConfiguration.class, McpServerStreamableHttpWebFluxAutoConfiguration.class, + McpServerAnnotationScannerAutoConfiguration.class, + McpServerSpecificationFactoryAutoConfiguration.class)); + + private final ApplicationContextRunner clientApplicationContext = new ApplicationContextRunner() + .withPropertyValues("spring.ai.anthropic.apiKey=" + System.getenv("ANTHROPIC_API_KEY")) + .withConfiguration(anthropicAutoConfig(McpToolCallbackAutoConfiguration.class, McpClientAutoConfiguration.class, + StreamableHttpWebFluxTransportAutoConfiguration.class, + McpClientAnnotationScannerAutoConfiguration.class, McpClientSpecificationFactoryAutoConfiguration.class, + AnthropicChatAutoConfiguration.class, ChatClientAutoConfiguration.class)); + + private static AutoConfigurations anthropicAutoConfig(Class... additional) { + Class[] dependencies = { SpringAiRetryAutoConfiguration.class, ToolCallingAutoConfiguration.class, + RestClientAutoConfiguration.class, WebClientAutoConfiguration.class }; + Class[] all = Stream.concat(Arrays.stream(dependencies), Arrays.stream(additional)).toArray(Class[]::new); + return AutoConfigurations.of(all); + } + + private static AtomicInteger toolCouter = new AtomicInteger(0); + + @Test + void clientServerCapabilities() { + + int serverPort = TestSocketUtils.findAvailableTcpPort(); + + this.serverContextRunner.withUserConfiguration(TestMcpServerConfiguration.class) + .withPropertyValues(// @formatter:off + "spring.ai.mcp.server.name=test-mcp-server", + "spring.ai.mcp.server.version=1.0.0", + "spring.ai.mcp.server.streamable-http.keep-alive-interval=1s", + "spring.ai.mcp.server.streamable-http.mcp-endpoint=/mcp") // @formatter:on + .run(serverContext -> { + // Verify all required beans are present + assertThat(serverContext).hasSingleBean(WebFluxStreamableServerTransportProvider.class); + assertThat(serverContext).hasSingleBean(RouterFunction.class); + assertThat(serverContext).hasSingleBean(McpSyncServer.class); + + // Verify server properties are configured correctly + McpServerProperties properties = serverContext.getBean(McpServerProperties.class); + assertThat(properties.getName()).isEqualTo("test-mcp-server"); + assertThat(properties.getVersion()).isEqualTo("1.0.0"); + + McpServerStreamableHttpProperties streamableHttpProperties = serverContext + .getBean(McpServerStreamableHttpProperties.class); + assertThat(streamableHttpProperties.getMcpEndpoint()).isEqualTo("/mcp"); + assertThat(streamableHttpProperties.getKeepAliveInterval()).isEqualTo(Duration.ofSeconds(1)); + + var httpServer = startHttpServer(serverContext, serverPort); + + this.clientApplicationContext.withUserConfiguration(TestMcpClientConfiguration.class) + .withUserConfiguration(TestMcpClientHandlers.class) + .withPropertyValues(// @formatter:off + "spring.ai.mcp.client.streamable-http.connections.server1.url=http://localhost:" + serverPort, + "spring.ai.mcp.client.initialized=false") // @formatter:on + .run(clientContext -> { + + ChatClient.Builder builder = clientContext.getBean(ChatClient.Builder.class); + + ToolCallbackProvider tcp = clientContext.getBean(ToolCallbackProvider.class); + + assertThat(builder).isNotNull(); + + ChatClient chatClient = builder.defaultToolCallbacks(tcp) + .defaultToolContext(Map.of("progressToken", "test-progress-token")) + .build(); + + String cResponse = chatClient.prompt() + .user("What is the weather in Amsterdam right now") + .call() + .content(); + + assertThat(cResponse).isNotEmpty(); + assertThat(cResponse).contains("22"); + + assertThat(toolCouter.get()).isEqualTo(1); + + // PROGRESS + TestMcpClientConfiguration.TestContext testContext = clientContext + .getBean(TestMcpClientConfiguration.TestContext.class); + assertThat(testContext.progressLatch.await(5, TimeUnit.SECONDS)) + .as("Should receive progress notifications in reasonable time") + .isTrue(); + assertThat(testContext.progressNotifications).hasSize(3); + + Map notificationMap = testContext.progressNotifications + .stream() + .collect(Collectors.toMap(n -> n.message(), n -> n)); + + // First notification should be 0.0/1.0 progress + assertThat(notificationMap.get("tool call start").progressToken()) + .isEqualTo("test-progress-token"); + assertThat(notificationMap.get("tool call start").progress()).isEqualTo(0.0); + assertThat(notificationMap.get("tool call start").total()).isEqualTo(1.0); + assertThat(notificationMap.get("tool call start").message()).isEqualTo("tool call start"); + + // Second notification should be 1.0/1.0 progress + assertThat(notificationMap.get("elicitation completed").progressToken()) + .isEqualTo("test-progress-token"); + assertThat(notificationMap.get("elicitation completed").progress()).isEqualTo(0.5); + assertThat(notificationMap.get("elicitation completed").total()).isEqualTo(1.0); + assertThat(notificationMap.get("elicitation completed").message()) + .isEqualTo("elicitation completed"); + + // Third notification should be 0.5/1.0 progress + assertThat(notificationMap.get("sampling completed").progressToken()) + .isEqualTo("test-progress-token"); + assertThat(notificationMap.get("sampling completed").progress()).isEqualTo(1.0); + assertThat(notificationMap.get("sampling completed").total()).isEqualTo(1.0); + assertThat(notificationMap.get("sampling completed").message()).isEqualTo("sampling completed"); + + }); + + stopHttpServer(httpServer); + }); + } + + // Helper methods to start and stop the HTTP server + private static DisposableServer startHttpServer(ApplicationContext serverContext, int port) { + WebFluxStreamableServerTransportProvider mcpStreamableServerTransport = serverContext + .getBean(WebFluxStreamableServerTransportProvider.class); + HttpHandler httpHandler = RouterFunctions.toHttpHandler(mcpStreamableServerTransport.getRouterFunction()); + ReactorHttpHandlerAdapter adapter = new ReactorHttpHandlerAdapter(httpHandler); + return HttpServer.create().port(port).handle(adapter).bindNow(); + } + + private static void stopHttpServer(DisposableServer server) { + if (server != null) { + server.disposeNow(); + } + } + + record ElicitInput(String message) { + } + + public static class TestMcpServerConfiguration { + + @Bean + public McpServerHandlers serverSideSpecProviders() { + return new McpServerHandlers(); + } + + public static class McpServerHandlers { + + @McpTool(description = "Provides weather information by city name") + public String weather(McpSyncRequestContext ctx, @McpToolParam String cityName) { + + toolCouter.incrementAndGet(); + + ctx.info("Weather called!"); + + ctx.progress(p -> p.progress(0.0).total(1.0).message("tool call start")); + + ctx.ping(); // call client ping + + // call elicitation + var elicitationResult = ctx.elicit(e -> e.message("Test message"), ElicitInput.class); + + ctx.progress(p -> p.progress(0.50).total(1.0).message("elicitation completed")); + + // call sampling + CreateMessageResult samplingResponse = ctx.sample(s -> s.message("Test Sampling Message") + .modelPreferences(pref -> pref.modelHints("OpenAi", "Ollama") + .costPriority(1.0) + .speedPriority(1.0) + .intelligencePriority(1.0))); + + ctx.progress(p -> p.progress(1.0).total(1.0).message("sampling completed")); + + ctx.info("Tool1 Done!"); + + return "Weahter is 22C with rain " + samplingResponse.toString() + ", " + elicitationResult.toString(); + } + + } + + } + + public static class TestMcpClientConfiguration { + + @Bean + public TestContext testContext() { + return new TestContext(); + } + + public static class TestContext { + + final AtomicReference loggingNotificationRef = new AtomicReference<>(); + + final CountDownLatch progressLatch = new CountDownLatch(3); + + final List progressNotifications = new CopyOnWriteArrayList<>(); + + } + + } + + public static class TestMcpClientHandlers { + + private static final Logger logger = LoggerFactory.getLogger(TestMcpClientHandlers.class); + + private final ChatClient client; + + private TestMcpClientConfiguration.TestContext testContext; + + public TestMcpClientHandlers(TestMcpClientConfiguration.TestContext testContext, + ChatClient.Builder clientBuilder) { + this.testContext = testContext; + this.client = clientBuilder.build(); + } + + @McpProgress(clients = "server1") + public void progressHandler(McpSchema.ProgressNotification progressNotification) { + logger.info("MCP PROGRESS: [{}] progress: {} total: {} message: {}", progressNotification.progressToken(), + progressNotification.progress(), progressNotification.total(), progressNotification.message()); + this.testContext.progressNotifications.add(progressNotification); + this.testContext.progressLatch.countDown(); + } + + @McpLogging(clients = "server1") + public void loggingHandler(McpSchema.LoggingMessageNotification loggingMessage) { + this.testContext.loggingNotificationRef.set(loggingMessage); + logger.info("MCP LOGGING: [{}] {}", loggingMessage.level(), loggingMessage.data()); + } + + @McpSampling(clients = "server1") + public McpSchema.CreateMessageResult samplingHandler(McpSchema.CreateMessageRequest llmRequest) { + logger.info("MCP SAMPLING: {}", llmRequest); + + String userPrompt = ((McpSchema.TextContent) llmRequest.messages().get(0).content()).text(); + String modelHint = llmRequest.modelPreferences().hints().get(0).name(); + // In a real use-case, we would use the chat client to call the LLM again + logger.info("MCP SAMPLING: simulating using chat client {}", this.client); + + return McpSchema.CreateMessageResult.builder() + .content(new McpSchema.TextContent("Response " + userPrompt + " with model hint " + modelHint)) + .build(); + } + + @McpElicitation(clients = "server1") + public StructuredElicitResult elicitationHandler(McpSchema.ElicitRequest request) { + logger.info("MCP ELICITATION: {}", request); + StreamableMcpAnnotationsWithLLMIT.ElicitInput elicitData = new StreamableMcpAnnotationsWithLLMIT.ElicitInput( + request.message()); + return StructuredElicitResult.builder().structuredContent(elicitData).build(); + } + + } + +} diff --git a/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/main/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfiguration.java b/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/main/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfiguration.java index a7dfbc74f40..e5d6699cd69 100644 --- a/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/main/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfiguration.java +++ b/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/main/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfiguration.java @@ -17,6 +17,7 @@ package org.springframework.ai.model.tool.autoconfigure; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import io.micrometer.observation.ObservationRegistry; @@ -35,7 +36,14 @@ import org.springframework.ai.tool.resolution.SpringBeanToolCallbackResolver; import org.springframework.ai.tool.resolution.StaticToolCallbackResolver; import org.springframework.ai.tool.resolution.ToolCallbackResolver; +import org.springframework.beans.BeansException; import org.springframework.beans.factory.ObjectProvider; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.beans.factory.config.BeanDefinition; +import org.springframework.beans.factory.support.BeanDefinitionBuilder; +import org.springframework.beans.factory.support.BeanDefinitionRegistry; +import org.springframework.beans.factory.support.BeanDefinitionRegistryPostProcessor; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean; @@ -43,6 +51,7 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.context.support.GenericApplicationContext; +import org.springframework.core.ResolvableType; import org.springframework.util.ClassUtils; /** @@ -56,17 +65,31 @@ @AutoConfiguration @ConditionalOnClass(ChatModel.class) @EnableConfigurationProperties(ToolCallingProperties.class) -public class ToolCallingAutoConfiguration { +public class ToolCallingAutoConfiguration implements BeanDefinitionRegistryPostProcessor { private static final Logger logger = LoggerFactory.getLogger(ToolCallingAutoConfiguration.class); + // Marker qualifier to exclude MCP-related ToolCallbackProviders + private static final String EXCLUDE_MCP_TOOL_CALLBACK_PROVIDER = "org.springframework.ai.model.tool.autoconfigure.ToolCallingAutoConfiguration.toolcallbackprovider.mcp-excluded"; + + /** + * The default {@link ToolCallbackResolver} resolves tools by name for methods, + * functions, and {@link ToolCallbackProvider} beans. + *

+ * MCP providers should not be injected to avoid cyclic dependencies. If some MCP + * providers are injected, we filter them out to avoid eagerly calling + * #getToolCallbacks. + */ @Bean @ConditionalOnMissingBean ToolCallbackResolver toolCallbackResolver(GenericApplicationContext applicationContext, - List toolCallbacks, List tcbProviders) { - + List toolCallbacks, + @Qualifier(EXCLUDE_MCP_TOOL_CALLBACK_PROVIDER) List tcbProviders) { List allFunctionAndToolCallbacks = new ArrayList<>(toolCallbacks); - tcbProviders.stream().map(pr -> List.of(pr.getToolCallbacks())).forEach(allFunctionAndToolCallbacks::addAll); + tcbProviders.stream() + .filter(pr -> !isMcpToolCallbackProvider(ResolvableType.forInstance(pr))) + .map(pr -> List.of(pr.getToolCallbacks())) + .forEach(allFunctionAndToolCallbacks::addAll); var staticToolCallbackResolver = new StaticToolCallbackResolver(allFunctionAndToolCallbacks); @@ -77,6 +100,50 @@ ToolCallbackResolver toolCallbackResolver(GenericApplicationContext applicationC return new DelegatingToolCallbackResolver(List.of(staticToolCallbackResolver, springBeanToolCallbackResolver)); } + /** + * Wrap {@link ToolCallbackProvider} beans that are not MCP-related into a named bean, + * which will be picked up by the + * {@link ToolCallingAutoConfiguration#toolCallbackResolver}. + *

+ * MCP providers must be excluded, because they may depend on a {@code ChatClient} to + * do sampling. The chat client, in turn, depends on a {@link ToolCallbackResolver}. + * To do the detection, we depend on the exposed bean type. If a bean uses a factory + * method which returns a {@link ToolCallbackProvider}, which is an MCP provider under + * the hood, it will be included in the list. + */ + @Override + public void postProcessBeanDefinitionRegistry(BeanDefinitionRegistry registry) throws BeansException { + if (!(registry instanceof DefaultListableBeanFactory beanFactory)) { + return; + } + + var excludeMcpToolCallbackProviderBeanDefinition = BeanDefinitionBuilder + .genericBeanDefinition(List.class, () -> { + var providerNames = beanFactory.getBeanNamesForType(ToolCallbackProvider.class); + return Arrays.stream(providerNames) + .filter(name -> !isMcpToolCallbackProvider(beanFactory.getBeanDefinition(name).getResolvableType())) + .map(beanFactory::getBean) + .filter(ToolCallbackProvider.class::isInstance) + .map(ToolCallbackProvider.class::cast) + .toList(); + }) + .setScope(BeanDefinition.SCOPE_SINGLETON) + .setLazyInit(true) + .getBeanDefinition(); + + registry.registerBeanDefinition(EXCLUDE_MCP_TOOL_CALLBACK_PROVIDER, + excludeMcpToolCallbackProviderBeanDefinition); + } + + private static boolean isMcpToolCallbackProvider(ResolvableType type) { + if (type.getType().getTypeName().equals("org.springframework.ai.mcp.SyncMcpToolCallbackProvider") + || type.getType().getTypeName().equals("org.springframework.ai.mcp.AsyncMcpToolCallbackProvider")) { + return true; + } + var superType = type.getSuperType(); + return superType != ResolvableType.NONE && isMcpToolCallbackProvider(superType); + } + @Bean @ConditionalOnMissingBean ToolExecutionExceptionProcessor toolExecutionExceptionProcessor(ToolCallingProperties properties) {