diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-common/src/main/java/org/springframework/ai/mcp/server/common/autoconfigure/StatelessToolCallbackConverterAutoConfiguration.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-common/src/main/java/org/springframework/ai/mcp/server/common/autoconfigure/StatelessToolCallbackConverterAutoConfiguration.java index d439db37063..55fb3336d99 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-common/src/main/java/org/springframework/ai/mcp/server/common/autoconfigure/StatelessToolCallbackConverterAutoConfiguration.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-common/src/main/java/org/springframework/ai/mcp/server/common/autoconfigure/StatelessToolCallbackConverterAutoConfiguration.java @@ -51,9 +51,11 @@ public class StatelessToolCallbackConverterAutoConfiguration { matchIfMissing = true) public List syncTools( ObjectProvider> toolCalls, List toolCallbackList, - List toolCallbackProvider, McpServerProperties serverProperties) { + ObjectProvider> tcbProviderList, + ObjectProvider tcbProviders, McpServerProperties serverProperties) { - List tools = this.aggregateToolCallbacks(toolCalls, toolCallbackList, toolCallbackProvider); + List tools = this.aggregateToolCallbacks(toolCalls, toolCallbackList, tcbProviderList, + tcbProviders); return this.toSyncToolSpecifications(tools, serverProperties); } @@ -81,9 +83,11 @@ private List toSyncToolSpecifi @ConditionalOnProperty(prefix = McpServerProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC") public List asyncTools( ObjectProvider> toolCalls, List toolCallbackList, - List toolCallbackProvider, McpServerProperties serverProperties) { + ObjectProvider> tcbProviderList, + ObjectProvider tcbProviders, McpServerProperties serverProperties) { - List tools = this.aggregateToolCallbacks(toolCalls, toolCallbackList, toolCallbackProvider); + List tools = this.aggregateToolCallbacks(toolCalls, toolCallbackList, tcbProviderList, + tcbProviders); return this.toAsyncToolSpecification(tools, serverProperties); } @@ -107,7 +111,16 @@ private List toAsyncToolSpeci } private List aggregateToolCallbacks(ObjectProvider> toolCalls, - List toolCallbacksList, List toolCallbackProvider) { + List toolCallbacksList, ObjectProvider> tcbProviderList, + ObjectProvider tcbProviders) { + + // Merge ToolCallbackProviders from both ObjectProviders. + List totalToolCallbackProviders = new ArrayList<>( + tcbProviderList.stream().flatMap(List::stream).toList()); + totalToolCallbackProviders.addAll(tcbProviders.stream().toList()); + + // De-duplicate ToolCallbackProviders + totalToolCallbackProviders = totalToolCallbackProviders.stream().distinct().toList(); List tools = new ArrayList<>(toolCalls.stream().flatMap(List::stream).toList()); @@ -115,7 +128,7 @@ private List aggregateToolCallbacks(ObjectProvider providerToolCallbacks = toolCallbackProvider.stream() + List providerToolCallbacks = totalToolCallbackProviders.stream() .map(pr -> List.of(pr.getToolCallbacks())) .flatMap(List::stream) .filter(fc -> fc instanceof ToolCallback) diff --git a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-common/src/main/java/org/springframework/ai/mcp/server/common/autoconfigure/ToolCallbackConverterAutoConfiguration.java b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-common/src/main/java/org/springframework/ai/mcp/server/common/autoconfigure/ToolCallbackConverterAutoConfiguration.java index 27dacd7912e..5c7197c80bc 100644 --- a/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-common/src/main/java/org/springframework/ai/mcp/server/common/autoconfigure/ToolCallbackConverterAutoConfiguration.java +++ b/auto-configurations/mcp/spring-ai-autoconfigure-mcp-server-common/src/main/java/org/springframework/ai/mcp/server/common/autoconfigure/ToolCallbackConverterAutoConfiguration.java @@ -49,10 +49,19 @@ public class ToolCallbackConverterAutoConfiguration { @ConditionalOnProperty(prefix = McpServerProperties.CONFIG_PREFIX, name = "type", havingValue = "SYNC", matchIfMissing = true) public List syncTools(ObjectProvider> toolCalls, - List toolCallbacksList, List toolCallbackProvider, - McpServerProperties serverProperties) { + List toolCallbacksList, ObjectProvider> tcbProviderList, + ObjectProvider tcbProviders, McpServerProperties serverProperties) { + + // Merge ToolCallbackProviders from both ObjectProviders. + List totalToolCallbackProviders = new ArrayList<>( + tcbProviderList.stream().flatMap(List::stream).toList()); + totalToolCallbackProviders.addAll(tcbProviders.stream().toList()); + + // De-duplicate ToolCallbackProviders + totalToolCallbackProviders = totalToolCallbackProviders.stream().distinct().toList(); - List tools = this.aggregateToolCallbacks(toolCalls, toolCallbacksList, toolCallbackProvider); + List tools = this.aggregateToolCallbacks(toolCalls, toolCallbacksList, tcbProviderList, + tcbProviders); return this.toSyncToolSpecifications(tools, serverProperties); } @@ -63,10 +72,7 @@ private List toSyncToolSpecifications(L // De-duplicate tools by their name, keeping the first occurrence of each tool // name return tools.stream() // Key: tool name - .collect(Collectors.toMap(tool -> tool.getToolDefinition().name(), tool -> tool, // Value: - // the - // tool - // itself + .collect(Collectors.toMap(tool -> tool.getToolDefinition().name(), tool -> tool, (existing, replacement) -> existing)) // On duplicate key, keep the // existing tool .values() @@ -83,10 +89,11 @@ private List toSyncToolSpecifications(L @Bean @ConditionalOnProperty(prefix = McpServerProperties.CONFIG_PREFIX, name = "type", havingValue = "ASYNC") public List asyncTools(ObjectProvider> toolCalls, - List toolCallbacksList, List toolCallbackProvider, - McpServerProperties serverProperties) { + List toolCallbacksList, ObjectProvider> tcbProviderList, + ObjectProvider tcbProviders, McpServerProperties serverProperties) { - List tools = this.aggregateToolCallbacks(toolCalls, toolCallbacksList, toolCallbackProvider); + List tools = this.aggregateToolCallbacks(toolCalls, toolCallbacksList, tcbProviderList, + tcbProviders); return this.toAsyncToolSpecification(tools, serverProperties); } @@ -114,7 +121,16 @@ private List toAsyncToolSpecification( } private List aggregateToolCallbacks(ObjectProvider> toolCalls, - List toolCallbacksList, List toolCallbackProvider) { + List toolCallbacksList, ObjectProvider> tcbProviderList, + ObjectProvider tcbProviders) { + + // Merge ToolCallbackProviders from both ObjectProviders. + List totalToolCallbackProviders = new ArrayList<>( + tcbProviderList.stream().flatMap(List::stream).toList()); + totalToolCallbackProviders.addAll(tcbProviders.stream().toList()); + + // De-duplicate ToolCallbackProviders + totalToolCallbackProviders = totalToolCallbackProviders.stream().distinct().toList(); List tools = new ArrayList<>(toolCalls.stream().flatMap(List::stream).toList()); @@ -122,7 +138,7 @@ private List aggregateToolCallbacks(ObjectProvider providerToolCallbacks = toolCallbackProvider.stream() + List providerToolCallbacks = totalToolCallbackProviders.stream() .map(pr -> List.of(pr.getToolCallbacks())) .flatMap(List::stream) .filter(fc -> fc instanceof ToolCallback) diff --git a/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/main/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfiguration.java b/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/main/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfiguration.java index bcdad6b2bf5..2b5ef331482 100644 --- a/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/main/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfiguration.java +++ b/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/main/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfiguration.java @@ -69,10 +69,24 @@ public class ToolCallingAutoConfiguration { */ @Bean @ConditionalOnMissingBean - ToolCallbackResolver toolCallbackResolver(GenericApplicationContext applicationContext, - List toolCallbacks, List tcbProviders) { + ToolCallbackResolver toolCallbackResolver( + GenericApplicationContext applicationContext, // @formatter:off + List toolCallbacks, + // Deprecated in favor of the tcbProviders. Kept for backward compatibility. + ObjectProvider> tcbProviderList, + ObjectProvider tcbProviders) { // @formatter:on + List allFunctionAndToolCallbacks = new ArrayList<>(toolCallbacks); - tcbProviders.stream() + + // Merge ToolCallbackProviders from both ObjectProviders. + List totalToolCallbackProviders = new ArrayList<>( + tcbProviderList.stream().flatMap(List::stream).toList()); + totalToolCallbackProviders.addAll(tcbProviders.stream().toList()); + + // De-duplicate ToolCallbackProviders + totalToolCallbackProviders = totalToolCallbackProviders.stream().distinct().toList(); + + totalToolCallbackProviders.stream() .filter(pr -> !isMcpToolCallbackProvider(ResolvableType.forInstance(pr))) .map(pr -> List.of(pr.getToolCallbacks())) .forEach(allFunctionAndToolCallbacks::addAll); diff --git a/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/test/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfigurationTests.java b/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/test/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfigurationTests.java index 64a080489ee..709a7d1d155 100644 --- a/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/test/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfigurationTests.java +++ b/auto-configurations/models/tool/spring-ai-autoconfigure-model-tool/src/test/java/org/springframework/ai/model/tool/autoconfigure/ToolCallingAutoConfigurationTests.java @@ -16,14 +16,20 @@ package org.springframework.ai.model.tool.autoconfigure; +import java.util.List; import java.util.function.Function; +import io.micrometer.observation.ObservationRegistry; import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.mcp.AsyncMcpToolCallbackProvider; import org.springframework.ai.mcp.SyncMcpToolCallbackProvider; import org.springframework.ai.model.tool.DefaultToolCallingManager; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.model.tool.ToolCallingManager; +import org.springframework.ai.model.tool.ToolExecutionResult; import org.springframework.ai.tool.StaticToolCallbackProvider; import org.springframework.ai.tool.ToolCallback; import org.springframework.ai.tool.ToolCallbackProvider; @@ -36,6 +42,7 @@ import org.springframework.ai.tool.method.MethodToolCallback; import org.springframework.ai.tool.method.MethodToolCallbackProvider; import org.springframework.ai.tool.observation.ToolCallingContentObservationFilter; +import org.springframework.ai.tool.observation.ToolCallingObservationConvention; import org.springframework.ai.tool.resolution.DelegatingToolCallbackResolver; import org.springframework.ai.tool.resolution.ToolCallbackResolver; import org.springframework.ai.tool.support.ToolDefinitions; @@ -213,6 +220,285 @@ void toolCallbackResolverDoesNotUseMcpToolCallbackProviders() { }); } + @Test + void customToolCallbackResolverOverridesDefault() { + new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(ToolCallingAutoConfiguration.class)) + .withUserConfiguration(CustomToolCallbackResolverConfig.class) + .run(context -> { + assertThat(context).hasBean("toolCallbackResolver"); + assertThat(context.getBean("toolCallbackResolver")).isInstanceOf(CustomToolCallbackResolver.class); + }); + } + + @Test + void customToolExecutionExceptionProcessorOverridesDefault() { + new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(ToolCallingAutoConfiguration.class)) + .withUserConfiguration(CustomToolExecutionExceptionProcessorConfig.class) + .run(context -> { + assertThat(context).hasBean("toolExecutionExceptionProcessor"); + assertThat(context.getBean("toolExecutionExceptionProcessor")) + .isInstanceOf(CustomToolExecutionExceptionProcessor.class); + }); + } + + @Test + void customToolCallingManagerOverridesDefault() { + new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(ToolCallingAutoConfiguration.class)) + .withUserConfiguration(CustomToolCallingManagerConfig.class) + .run(context -> { + assertThat(context).hasBean("toolCallingManager"); + assertThat(context.getBean("toolCallingManager")).isInstanceOf(CustomToolCallingManager.class); + }); + } + + @Test + void observationContentFilterNotCreatedWhenPropertyDisabled() { + new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(ToolCallingAutoConfiguration.class)) + .withPropertyValues("spring.ai.tools.observations.include-content=false") + .run(context -> { + assertThat(context).doesNotHaveBean("toolCallingContentObservationFilter"); + assertThat(context).doesNotHaveBean(ToolCallingContentObservationFilter.class); + }); + } + + @Test + void toolCallbackResolverResolvesToolCallbacksFromBeans() { + new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(ToolCallingAutoConfiguration.class)) + .withUserConfiguration(ToolCallbackBeansConfig.class) + .run(context -> { + var resolver = context.getBean(ToolCallbackResolver.class); + + assertThat(resolver.resolve("getWeather")).isNotNull(); + assertThat(resolver.resolve("getWeather").getToolDefinition().name()).isEqualTo("getWeather"); + + assertThat(resolver.resolve("weatherFunction")).isNotNull(); + assertThat(resolver.resolve("weatherFunction").getToolDefinition().name()).isEqualTo("weatherFunction"); + + assertThat(resolver.resolve("nonExistentTool")).isNull(); + }); + } + + @Test + void toolCallbackResolverResolvesMethodToolCallbacks() { + new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(ToolCallingAutoConfiguration.class)) + .withUserConfiguration(MethodToolCallbackConfig.class) + .run(context -> { + var resolver = context.getBean(ToolCallbackResolver.class); + + assertThat(resolver.resolve("getForecastMethod")).isNotNull(); + assertThat(resolver.resolve("getForecastMethod").getToolDefinition().name()) + .isEqualTo("getForecastMethod"); + }); + } + + @Test + void toolCallingManagerIntegrationWithCustomComponents() { + new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(ToolCallingAutoConfiguration.class)) + .withUserConfiguration(CustomObservationConfig.class) + .run(context -> { + assertThat(context).hasBean("toolCallingManager"); + assertThat(context).hasBean("customObservationRegistry"); + assertThat(context).hasBean("customObservationConvention"); + + var manager = context.getBean(ToolCallingManager.class); + assertThat(manager).isNotNull(); + }); + } + + @Test + void toolCallbackProviderBeansAreResolved() { + new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(ToolCallingAutoConfiguration.class)) + .withUserConfiguration(ToolCallbackProviderConfig.class) + .run(context -> { + var resolver = context.getBean(ToolCallbackResolver.class); + + // Should resolve tools from the ToolCallbackProvider + assertThat(resolver.resolve("providerTool")).isNotNull(); + assertThat(resolver.resolve("providerTool").getToolDefinition().name()).isEqualTo("providerTool"); + }); + } + + @Test + void multipleToolCallbackProvidersAreResolved() { + new ApplicationContextRunner().withConfiguration(AutoConfigurations.of(ToolCallingAutoConfiguration.class)) + .withUserConfiguration(MultipleToolCallbackProvidersConfig.class) + .run(context -> { + var resolver = context.getBean(ToolCallbackResolver.class); + + // Should resolve tools from both providers + assertThat(resolver.resolve("tool1")).isNotNull(); + assertThat(resolver.resolve("tool2")).isNotNull(); + assertThat(resolver.resolve("tool3")).isNotNull(); + }); + } + + @Configuration + static class CustomToolCallbackResolverConfig { + + @Bean + public ToolCallbackResolver toolCallbackResolver() { + return new CustomToolCallbackResolver(); + } + + } + + static class CustomToolCallbackResolver implements ToolCallbackResolver { + + @Override + public ToolCallback resolve(String toolName) { + return null; + } + + } + + @Configuration + static class CustomToolExecutionExceptionProcessorConfig { + + @Bean + public ToolExecutionExceptionProcessor toolExecutionExceptionProcessor() { + return new CustomToolExecutionExceptionProcessor(); + } + + } + + static class CustomToolExecutionExceptionProcessor implements ToolExecutionExceptionProcessor { + + @Override + public String process(ToolExecutionException exception) { + return "Custom error handling"; + } + + } + + @Configuration + static class CustomToolCallingManagerConfig { + + @Bean + public ToolCallingManager toolCallingManager(ToolCallbackResolver resolver, + ToolExecutionExceptionProcessor processor) { + return new CustomToolCallingManager(); + } + + } + + static class CustomToolCallingManager implements ToolCallingManager { + + @Override + public List resolveToolDefinitions(ToolCallingChatOptions options) { + return List.of(); + } + + @Override + public ToolExecutionResult executeToolCalls(Prompt prompt, ChatResponse chatResponse) { + return null; + } + + } + + @Configuration + static class ToolCallbackBeansConfig { + + @Bean + public ToolCallback getWeather() { + return FunctionToolCallback.builder("getWeather", (Request request) -> "Sunny, 25°C") + .description("Gets the current weather") + .inputType(Request.class) + .build(); + } + + @Bean + @Description("Get weather forecast") + public Function weatherFunction() { + return request -> new Response("Sunny"); + } + + } + + @Configuration + static class MethodToolCallbackConfig { + + @Bean + public ToolCallbackProvider methodToolCallbacks() { + return MethodToolCallbackProvider.builder().toolObjects(new WeatherServiceForMethod()).build(); + } + + } + + static class WeatherServiceForMethod { + + @Tool(description = "Get the weather forecast") + public String getForecastMethod(String location) { + return "Sunny, 25°C"; + } + + } + + @Configuration + static class CustomObservationConfig { + + @Bean + public ObservationRegistry customObservationRegistry() { + return ObservationRegistry.create(); + } + + @Bean + public ToolCallingObservationConvention customObservationConvention() { + return new ToolCallingObservationConvention() { + }; + } + + } + + @Configuration + static class ToolCallbackProviderConfig { + + @Bean + public ToolCallbackProvider toolCallbackProvider() { + return () -> new ToolCallback[] { + FunctionToolCallback.builder("providerTool", (Request request) -> "Result") + .description("Tool from provider") + .inputType(Request.class) + .build() }; + } + + } + + @Configuration + static class MultipleToolCallbackProvidersConfig { + + @Bean + public ToolCallbackProvider toolCallbackProvider1() { + return () -> new ToolCallback[] { FunctionToolCallback.builder("tool1", (Request request) -> "Result1") + .description("Tool 1") + .inputType(Request.class) + .build() }; + } + + @Bean + public ToolCallbackProvider toolCallbackProvider2() { + return () -> new ToolCallback[] { FunctionToolCallback.builder("tool2", (Request request) -> "Result2") + .description("Tool 2") + .inputType(Request.class) + .build() }; + } + + @Bean + public List toolCallbackProviderList() { + return List + .of(() -> new ToolCallback[] { FunctionToolCallback.builder("tool3", (Request request) -> "Result3") + .description("Tool 3") + .inputType(Request.class) + .build() }); + } + + } + + public record Request(String location) { + } + + public record Response(String temperature) { + } + static class WeatherService { @Tool(description = "Get the weather in location. Return temperature in 36°F or 36°C format.") @@ -309,12 +595,6 @@ public AsyncMcpToolCallbackProvider asyncMcpToolCallbackProvider() { return provider; } - public record Request(String location) { - } - - public record Response(String temperature) { - } - } }