diff --git a/spring-ai-core/src/main/java/org/springframework/ai/aot/ToolBeanRegistrationAotProcessor.java b/spring-ai-core/src/main/java/org/springframework/ai/aot/ToolBeanRegistrationAotProcessor.java new file mode 100644 index 00000000000..f13aa0a6aef --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/aot/ToolBeanRegistrationAotProcessor.java @@ -0,0 +1,79 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.aot; + +import org.springframework.ai.tool.annotation.Tool; +import org.springframework.aot.generate.GenerationContext; +import org.springframework.aot.hint.MemberCategory; +import org.springframework.aot.hint.ReflectionHints; +import org.springframework.beans.factory.aot.BeanRegistrationAotContribution; +import org.springframework.beans.factory.aot.BeanRegistrationAotProcessor; +import org.springframework.beans.factory.aot.BeanRegistrationCode; +import org.springframework.beans.factory.support.RegisteredBean; +import org.springframework.core.annotation.MergedAnnotations; +import org.springframework.lang.Nullable; +import org.springframework.util.ReflectionUtils; + +import java.util.stream.Stream; + +import static org.springframework.core.annotation.MergedAnnotations.SearchStrategy.TYPE_HIERARCHY; + +/** + * AOT {@code BeanRegistrationAotProcessor} that detects the presence of the {@link Tool} + * annotation on methods and creates the required reflection hints. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +class ToolBeanRegistrationAotProcessor implements BeanRegistrationAotProcessor { + + @Override + @Nullable + public BeanRegistrationAotContribution processAheadOfTime(RegisteredBean registeredBean) { + Class beanClass = registeredBean.getBeanClass(); + MergedAnnotations.Search search = MergedAnnotations.search(TYPE_HIERARCHY); + + boolean hasAnyToolAnnotatedMethods = Stream.of(ReflectionUtils.getDeclaredMethods(beanClass)) + .anyMatch(method -> search.from(method).isPresent(Tool.class)); + + if (hasAnyToolAnnotatedMethods) { + return new AotContribution(beanClass); + } + + return null; + } + + private static class AotContribution implements BeanRegistrationAotContribution { + + private final MemberCategory[] memberCategories = new MemberCategory[] { MemberCategory.INVOKE_DECLARED_METHODS, + MemberCategory.INVOKE_PUBLIC_METHODS }; + + private final Class toolClass; + + public AotContribution(Class toolClass) { + this.toolClass = toolClass; + } + + @Override + public void applyTo(GenerationContext generationContext, BeanRegistrationCode beanRegistrationCode) { + ReflectionHints reflectionHints = generationContext.getRuntimeHints().reflection(); + reflectionHints.registerType(toolClass, memberCategories); + } + + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/aot/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/aot/package-info.java new file mode 100644 index 00000000000..ebfb021c996 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/aot/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@NonNullApi +@NonNullFields +package org.springframework.ai.aot; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java index c4ddfbb219e..7ae2a17ad49 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import org.springframework.ai.chat.model.ToolContext; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.util.Assert; /** @@ -41,7 +42,9 @@ * @param the 3rd party service input type. * @param the 3rd party service output type. * @author Christian Tzolov + * @deprecated in favor of {@link FunctionToolCallback}. */ +@Deprecated abstract class AbstractFunctionCallback implements BiFunction, FunctionCallback { private final String name; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultCommonCallbackInvokingSpec.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultCommonCallbackInvokingSpec.java index dd55f42064f..e84014c23b7 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultCommonCallbackInvokingSpec.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultCommonCallbackInvokingSpec.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2024-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,9 +26,16 @@ import org.springframework.ai.model.function.FunctionCallback.CommonCallbackInvokingSpec; import org.springframework.ai.model.function.FunctionCallback.SchemaType; +import org.springframework.ai.tool.function.FunctionToolCallback; +import org.springframework.ai.tool.method.MethodToolCallback; import org.springframework.ai.util.JacksonUtils; import org.springframework.util.Assert; +/** + * @deprecated Use specific builder for the type of tool you need, e.g. + * {@link FunctionToolCallback.Builder} and {@link MethodToolCallback.Builder}. + */ +@Deprecated public class DefaultCommonCallbackInvokingSpec> implements CommonCallbackInvokingSpec { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallbackBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallbackBuilder.java index 3cb6fc11764..d18e0bf5dae 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallbackBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallbackBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,6 +31,8 @@ import org.springframework.ai.model.function.FunctionCallback.FunctionInvokingSpec; import org.springframework.ai.model.function.FunctionCallback.MethodInvokingSpec; import org.springframework.ai.model.function.FunctionCallback.SchemaType; +import org.springframework.ai.tool.function.FunctionToolCallback; +import org.springframework.ai.tool.method.MethodToolCallback; import org.springframework.ai.util.ParsingUtils; import org.springframework.core.ParameterizedTypeReference; import org.springframework.util.Assert; @@ -42,7 +44,10 @@ * * @author Christian Tzolov * @since 1.0.0 + * @deprecated Use specific builder for the type of tool you need, e.g. + * {@link FunctionToolCallback.Builder} and {@link MethodToolCallback.Builder}. */ +@Deprecated public class DefaultFunctionCallbackBuilder implements FunctionCallback.Builder { private static final Logger logger = LoggerFactory.getLogger(DefaultFunctionCallbackBuilder.class); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallingOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallingOptions.java index 14ccccf1050..9d1456f29a8 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallingOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallingOptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2024-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -25,6 +25,7 @@ import java.util.Set; import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.model.tool.DefaultToolCallingChatOptions; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; @@ -35,7 +36,9 @@ * @author Christian Tzolov * @author Thomas Vitale * @author Ilayaperumal Gopinathan + * @deprecated in favor of {@link DefaultToolCallingChatOptions}. */ +@Deprecated public class DefaultFunctionCallingOptions implements FunctionCallingOptions { private List functionCallbacks = new ArrayList<>(); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallingOptionsBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallingOptionsBuilder.java index 15f8926e78a..85b58101703 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallingOptionsBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/DefaultFunctionCallingOptionsBuilder.java @@ -1,5 +1,5 @@ /* - * Copyright 2024-2024 the original author or authors. + * Copyright 2024-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,6 +22,7 @@ import java.util.Map; import java.util.Set; +import org.springframework.ai.model.tool.DefaultToolCallingChatOptions; import org.springframework.util.Assert; /** @@ -30,7 +31,9 @@ * @author Christian Tzolov * @author Thomas Vitale * @author Ilayaperumal Gopinathan + * @deprecated in favor of {@link DefaultToolCallingChatOptions.Builder}. */ +@Deprecated public class DefaultFunctionCallingOptionsBuilder implements FunctionCallingOptions.Builder { private final DefaultFunctionCallingOptions options; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java index 06ac88fda3b..225752d6d3d 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,9 @@ import com.fasterxml.jackson.databind.ObjectMapper; import org.springframework.ai.chat.model.ToolContext; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.function.FunctionToolCallback; +import org.springframework.ai.tool.method.MethodToolCallback; import org.springframework.core.ParameterizedTypeReference; /** @@ -31,7 +34,9 @@ * Models and called on prompts that trigger the function call. * * @author Christian Tzolov + * @deprecated in favor of {@link ToolCallback}. */ +@Deprecated public interface FunctionCallback { /** @@ -115,7 +120,11 @@ enum SchemaType { *
  • {@link FunctionInvokingSpec} - The function invoking builder interface. *
  • {@link MethodInvokingSpec} - The method invoking builder interface. * + * + * @deprecated Use specific builder for the type of tool you need, e.g. + * {@link FunctionToolCallback.Builder} and {@link MethodToolCallback.Builder}. */ + @Deprecated interface Builder { /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java index ba8529e684f..1e886f6cb8b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallingOptions.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,6 +21,7 @@ import java.util.Set; import org.springframework.ai.chat.prompt.ChatOptions; +import org.springframework.ai.model.tool.ToolCallingChatOptions; /** * FunctionCallingOptions is a set of options that can be used to configure the function @@ -28,7 +29,9 @@ * * @author Christian Tzolov * @author Ilayaperumal Gopinathan + * @deprecated in favor of {@link ToolCallingChatOptions}. */ +@Deprecated public interface FunctionCallingOptions extends ChatOptions { /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionInvokingFunctionCallback.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionInvokingFunctionCallback.java index a03c7b1b7aa..0fed96ab0ce 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionInvokingFunctionCallback.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionInvokingFunctionCallback.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -23,6 +23,7 @@ import com.fasterxml.jackson.databind.ObjectMapper; import org.springframework.ai.chat.model.ToolContext; +import org.springframework.ai.tool.function.FunctionToolCallback; import org.springframework.util.Assert; /** @@ -34,7 +35,9 @@ * @param the input type * @param the output type * @author Christian Tzolov + * @deprecated in favor of {@link FunctionToolCallback}. */ +@Deprecated public final class FunctionInvokingFunctionCallback extends AbstractFunctionCallback { private final BiFunction biFunction; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/MethodInvokingFunctionCallback.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/MethodInvokingFunctionCallback.java index 9e23370b4f8..e3a724b1548 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/MethodInvokingFunctionCallback.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/MethodInvokingFunctionCallback.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,6 +34,7 @@ import org.springframework.ai.chat.model.ToolContext; import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.tool.method.MethodToolCallback; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.CollectionUtils; @@ -51,7 +52,9 @@ * * @author Christian Tzolov * @since 1.0.0 + * @deprecated in favor of {@link MethodToolCallback}. */ +@Deprecated public class MethodInvokingFunctionCallback implements FunctionCallback { private static final Logger logger = LoggerFactory.getLogger(MethodInvokingFunctionCallback.class); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/ToolCallback.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/ToolCallback.java index f26b3d06eee..7ca2435a6f6 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/tool/ToolCallback.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/ToolCallback.java @@ -41,18 +41,21 @@ default ToolMetadata getToolMetadata() { } @Override + @Deprecated // Call getToolDefinition().name() instead default String getName() { return getToolDefinition().name(); } @Override + @Deprecated // Call getToolDefinition().description() instead default String getDescription() { return getToolDefinition().description(); } @Override + @Deprecated // Call getToolDefinition().inputTypeSchema() instead default String getInputTypeSchema() { - return getToolDefinition().inputTypeSchema(); + return getToolDefinition().inputSchema(); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/definition/DefaultToolDefinition.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/definition/DefaultToolDefinition.java index 83c9327bf12..3c2e9056249 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/tool/definition/DefaultToolDefinition.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/definition/DefaultToolDefinition.java @@ -16,7 +16,9 @@ package org.springframework.ai.tool.definition; +import org.springframework.ai.tool.util.ToolUtils; import org.springframework.util.Assert; +import org.springframework.util.StringUtils; /** * Default implementation of {@link ToolDefinition}. @@ -24,12 +26,12 @@ * @author Thomas Vitale * @since 1.0.0 */ -public record DefaultToolDefinition(String name, String description, String inputTypeSchema) implements ToolDefinition { +public record DefaultToolDefinition(String name, String description, String inputSchema) implements ToolDefinition { public DefaultToolDefinition { Assert.hasText(name, "name cannot be null or empty"); Assert.hasText(description, "description cannot be null or empty"); - Assert.hasText(inputTypeSchema, "inputTypeSchema cannot be null or empty"); + Assert.hasText(inputSchema, "inputSchema cannot be null or empty"); } public static Builder builder() { @@ -42,7 +44,7 @@ public static class Builder { private String description; - private String inputTypeSchema; + private String inputSchema; private Builder() { } @@ -57,13 +59,16 @@ public Builder description(String description) { return this; } - public Builder inputTypeSchema(String inputTypeSchema) { - this.inputTypeSchema = inputTypeSchema; + public Builder inputSchema(String inputSchema) { + this.inputSchema = inputSchema; return this; } public DefaultToolDefinition build() { - return new DefaultToolDefinition(name, description, inputTypeSchema); + if (!StringUtils.hasText(description)) { + description = ToolUtils.getToolDescriptionFromName(description); + } + return new DefaultToolDefinition(name, description, inputSchema); } } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/definition/ToolDefinition.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/definition/ToolDefinition.java index dcc42297482..d25367c5bfe 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/tool/definition/ToolDefinition.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/definition/ToolDefinition.java @@ -40,9 +40,9 @@ public interface ToolDefinition { String description(); /** - * The JSON Schema of the parameters used to call the tool. + * The schema of the parameters used to call the tool. */ - String inputTypeSchema(); + String inputSchema(); /** * Create a default {@link ToolDefinition} builder. @@ -58,7 +58,7 @@ static ToolDefinition from(Method method) { return DefaultToolDefinition.builder() .name(ToolUtils.getToolName(method)) .description(ToolUtils.getToolDescription(method)) - .inputTypeSchema(JsonSchemaGenerator.generateForMethodInput(method)) + .inputSchema(JsonSchemaGenerator.generateForMethodInput(method)) .build(); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/DefaultToolCallResultConverter.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/DefaultToolCallResultConverter.java index 8eebbfa7ae1..a43dcd237c3 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/DefaultToolCallResultConverter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/DefaultToolCallResultConverter.java @@ -16,9 +16,12 @@ package org.springframework.ai.tool.execution; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.springframework.ai.util.json.JsonParser; import org.springframework.lang.Nullable; -import org.springframework.util.Assert; + +import java.lang.reflect.Type; /** * A default implementation of {@link ToolCallResultConverter}. @@ -28,10 +31,12 @@ */ public final class DefaultToolCallResultConverter implements ToolCallResultConverter { + private static final Logger logger = LoggerFactory.getLogger(DefaultToolCallResultConverter.class); + @Override - public String apply(@Nullable Object result, Class returnType) { - Assert.notNull(returnType, "returnType cannot be null"); + public String apply(@Nullable Object result, @Nullable Type returnType) { if (returnType == Void.TYPE) { + logger.debug("The tool has no return type. Converting to conventional response."); return "Done"; } else { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/ToolCallResultConverter.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/ToolCallResultConverter.java index cb08edf768b..d302068187e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/ToolCallResultConverter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/execution/ToolCallResultConverter.java @@ -18,6 +18,7 @@ import org.springframework.lang.Nullable; +import java.lang.reflect.Type; import java.util.function.BiFunction; /** @@ -28,12 +29,12 @@ * @since 1.0.0 */ @FunctionalInterface -public interface ToolCallResultConverter extends BiFunction, String> { +public interface ToolCallResultConverter extends BiFunction { /** * Given an Object returned by a tool, convert it to a String compatible with the * given class type. */ - String apply(@Nullable Object result, Class returnType); + String apply(@Nullable Object result, @Nullable Type returnType); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/function/FunctionToolCallback.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/function/FunctionToolCallback.java new file mode 100644 index 00000000000..e708efa5304 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/function/FunctionToolCallback.java @@ -0,0 +1,212 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool.function; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.model.ToolContext; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.DefaultToolCallResultConverter; +import org.springframework.ai.tool.execution.ToolCallResultConverter; +import org.springframework.ai.tool.metadata.ToolMetadata; +import org.springframework.ai.tool.util.ToolUtils; +import org.springframework.ai.util.json.JsonParser; +import org.springframework.ai.util.json.JsonSchemaGenerator; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +import java.lang.reflect.Type; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; + +/** + * A {@link ToolCallback} implementation to invoke functions as tools. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public class FunctionToolCallback implements ToolCallback { + + private static final Logger logger = LoggerFactory.getLogger(FunctionToolCallback.class); + + private static final ToolCallResultConverter DEFAULT_RESULT_CONVERTER = new DefaultToolCallResultConverter(); + + private static final ToolMetadata DEFAULT_TOOL_METADATA = ToolMetadata.builder().build(); + + private final ToolDefinition toolDefinition; + + private final ToolMetadata toolMetadata; + + private final Type toolInputType; + + private final BiFunction toolFunction; + + private final ToolCallResultConverter toolCallResultConverter; + + public FunctionToolCallback(ToolDefinition toolDefinition, @Nullable ToolMetadata toolMetadata, Type toolInputType, + BiFunction toolFunction, @Nullable ToolCallResultConverter toolCallResultConverter) { + Assert.notNull(toolDefinition, "toolDefinition cannot be null"); + Assert.notNull(toolInputType, "toolInputType cannot be null"); + Assert.notNull(toolFunction, "toolFunction cannot be null"); + this.toolDefinition = toolDefinition; + this.toolMetadata = toolMetadata != null ? toolMetadata : DEFAULT_TOOL_METADATA; + this.toolFunction = toolFunction; + this.toolInputType = toolInputType; + this.toolCallResultConverter = toolCallResultConverter != null ? toolCallResultConverter + : DEFAULT_RESULT_CONVERTER; + } + + @Override + public ToolDefinition getToolDefinition() { + return toolDefinition; + } + + @Override + public ToolMetadata getToolMetadata() { + return toolMetadata; + } + + @Override + public String call(String toolInput) { + return call(toolInput, null); + } + + @Override + public String call(String toolInput, @Nullable ToolContext toolContext) { + Assert.hasText(toolInput, "toolInput cannot be null or empty"); + + logger.debug("Starting execution of tool: {}", toolDefinition.name()); + + I request = JsonParser.fromJson(toolInput, toolInputType); + O response = toolFunction.apply(request, toolContext); + + logger.debug("Successful execution of tool: {}", toolDefinition.name()); + + return toolCallResultConverter.apply(response, null); + } + + /** + * Build a {@link FunctionToolCallback} from a {@link BiFunction}. + */ + public static Builder builder(String name, BiFunction function) { + return new Builder<>(name, function); + } + + /** + * Build a {@link FunctionToolCallback} from a {@link Function}. + */ + public static Builder builder(String name, Function function) { + Assert.notNull(function, "function cannot be null"); + return new Builder<>(name, (request, context) -> function.apply(request)); + } + + /** + * Build a {@link FunctionToolCallback} from a {@link Supplier}. + */ + public static Builder builder(String name, Supplier supplier) { + Assert.notNull(supplier, "supplier cannot be null"); + Function function = input -> supplier.get(); + return builder(name, function).inputType(Void.class); + } + + /** + * Build a {@link FunctionToolCallback} from a {@link Consumer}. + */ + public static Builder builder(String name, Consumer consumer) { + Assert.notNull(consumer, "consumer cannot be null"); + Function function = (I input) -> { + consumer.accept(input); + return null; + }; + return builder(name, function); + } + + public static class Builder { + + private String name; + + private String description; + + private String inputSchema; + + private Type inputType; + + private ToolMetadata toolMetadata; + + private BiFunction toolFunction; + + private ToolCallResultConverter toolCallResultConverter; + + private Builder(String name, BiFunction toolFunction) { + Assert.hasText(name, "name cannot be null or empty"); + Assert.notNull(toolFunction, "toolFunction cannot be null"); + this.name = name; + this.toolFunction = toolFunction; + } + + public Builder description(String description) { + this.description = description; + return this; + } + + public Builder inputSchema(String inputSchema) { + this.inputSchema = inputSchema; + return this; + } + + public Builder inputType(Type inputType) { + this.inputType = inputType; + return this; + } + + public Builder inputType(ParameterizedTypeReference inputType) { + Assert.notNull(inputType, "inputType cannot be null"); + this.inputType = inputType.getType(); + return this; + } + + public Builder toolMetadata(ToolMetadata toolMetadata) { + this.toolMetadata = toolMetadata; + return this; + } + + public Builder toolCallResultConverter(ToolCallResultConverter toolCallResultConverter) { + this.toolCallResultConverter = toolCallResultConverter; + return this; + } + + public FunctionToolCallback build() { + Assert.notNull(inputType, "inputType cannot be null"); + var toolDefinition = ToolDefinition.builder() + .name(name) + .description( + StringUtils.hasText(description) ? description : ToolUtils.getToolDescriptionFromName(name)) + .inputSchema( + StringUtils.hasText(inputSchema) ? inputSchema : JsonSchemaGenerator.generateForType(inputType)) + .build(); + return new FunctionToolCallback<>(toolDefinition, toolMetadata, inputType, toolFunction, + toolCallResultConverter); + } + + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/function/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/function/package-info.java new file mode 100644 index 00000000000..76bce773801 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/function/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +@NonNullApi +@NonNullFields +package org.springframework.ai.tool.function; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java index 424105671d9..70137e2d37d 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java @@ -35,6 +35,7 @@ import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; import java.lang.reflect.Modifier; +import java.lang.reflect.Type; import java.util.Map; import java.util.stream.Stream; @@ -50,6 +51,8 @@ public class MethodToolCallback implements ToolCallback { private static final ToolCallResultConverter DEFAULT_RESULT_CONVERTER = new DefaultToolCallResultConverter(); + private static final ToolMetadata DEFAULT_TOOL_METADATA = ToolMetadata.builder().build(); + private final ToolDefinition toolDefinition; private final ToolMetadata toolMetadata; @@ -60,14 +63,13 @@ public class MethodToolCallback implements ToolCallback { private final ToolCallResultConverter toolCallResultConverter; - public MethodToolCallback(ToolDefinition toolDefinition, ToolMetadata toolMetadata, Method toolMethod, + public MethodToolCallback(ToolDefinition toolDefinition, @Nullable ToolMetadata toolMetadata, Method toolMethod, Object toolObject, @Nullable ToolCallResultConverter toolCallResultConverter) { Assert.notNull(toolDefinition, "toolDefinition cannot be null"); - Assert.notNull(toolMetadata, "toolMetadata cannot be null"); Assert.notNull(toolMethod, "toolMethod cannot be null"); Assert.notNull(toolObject, "toolObject cannot be null"); this.toolDefinition = toolDefinition; - this.toolMetadata = toolMetadata; + this.toolMetadata = toolMetadata != null ? toolMetadata : DEFAULT_TOOL_METADATA; this.toolMethod = toolMethod; this.toolObject = toolObject; this.toolCallResultConverter = toolCallResultConverter != null ? toolCallResultConverter @@ -105,7 +107,7 @@ public String call(String toolInput, @Nullable ToolContext toolContext) { logger.debug("Successful execution of tool: {}", toolDefinition.name()); - Class returnType = toolMethod.getReturnType(); + Type returnType = toolMethod.getGenericReturnType(); return toolCallResultConverter.apply(result, returnType); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/util/ToolUtils.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/util/ToolUtils.java index f9aebeef205..786d97e92f6 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/tool/util/ToolUtils.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/util/ToolUtils.java @@ -21,6 +21,8 @@ import org.springframework.ai.tool.execution.DefaultToolCallResultConverter; import org.springframework.ai.tool.execution.ToolCallResultConverter; import org.springframework.ai.util.ParsingUtils; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; import org.springframework.util.StringUtils; import java.lang.reflect.Method; @@ -47,6 +49,11 @@ public static String getToolName(Method method) { return StringUtils.hasText(tool.name()) ? tool.name() : method.getName(); } + public static String getToolDescriptionFromName(@Nullable String toolName) { + Assert.hasText(toolName, "toolName cannot be null or empty"); + return ParsingUtils.reConcatenateCamelCase(toolName, " "); + } + public static String getToolDescription(Method method) { var tool = method.getAnnotation(Tool.class); if (tool == null) { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/util/json/JsonParser.java b/spring-ai-core/src/main/java/org/springframework/ai/util/json/JsonParser.java index cf0acb2d3c5..b4d184de659 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/util/json/JsonParser.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/util/json/JsonParser.java @@ -27,6 +27,8 @@ import org.springframework.util.Assert; import org.springframework.util.ClassUtils; +import java.lang.reflect.Type; + /** * Utilities to perform parsing operations between JSON and Java. */ @@ -64,6 +66,21 @@ public static T fromJson(String json, Class type) { } } + /** + * Converts a JSON string to a Java object. + */ + public static T fromJson(String json, Type type) { + Assert.notNull(json, "json cannot be null"); + Assert.notNull(type, "type cannot be null"); + + try { + return OBJECT_MAPPER.readValue(json, OBJECT_MAPPER.constructType(type)); + } + catch (JsonProcessingException ex) { + throw new IllegalStateException("Conversion from JSON to %s failed".formatted(type.getTypeName()), ex); + } + } + /** * Converts a JSON string to a Java object. */ diff --git a/spring-ai-core/src/main/java/org/springframework/ai/util/json/JsonSchemaGenerator.java b/spring-ai-core/src/main/java/org/springframework/ai/util/json/JsonSchemaGenerator.java index 6196de3b05f..6e6bb0bbb13 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/util/json/JsonSchemaGenerator.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/util/json/JsonSchemaGenerator.java @@ -112,6 +112,9 @@ public static String generateForMethodInput(Method method, SchemaOption... schem public static String generateForType(Type type, SchemaOption... schemaOptions) { Assert.notNull(type, "type cannot be null"); ObjectNode schema = TYPE_SCHEMA_GENERATOR.generateSchema(type); + if ((type == Void.class) && !schema.has("properties")) { + schema.putObject("properties"); + } if (Stream.of(schemaOptions) .noneMatch(option -> option == SchemaOption.ALLOW_ADDITIONAL_PROPERTIES_BY_DEFAULT)) { schema.put("additionalProperties", false); diff --git a/spring-ai-core/src/main/resources/META-INF/spring/aot.factories b/spring-ai-core/src/main/resources/META-INF/spring/aot.factories index 05bc046104d..b7f9b1b63c9 100644 --- a/spring-ai-core/src/main/resources/META-INF/spring/aot.factories +++ b/spring-ai-core/src/main/resources/META-INF/spring/aot.factories @@ -1,4 +1,7 @@ org.springframework.aot.hint.RuntimeHintsRegistrar=\ org.springframework.ai.aot.SpringAiCoreRuntimeHints,\ org.springframework.ai.aot.KnuddelsRuntimeHints,\ - org.springframework.ai.aot.ToolRuntimeHints \ No newline at end of file + org.springframework.ai.aot.ToolRuntimeHints + +org.springframework.beans.factory.aot.BeanRegistrationAotProcessor=\ + org.springframework.ai.aot.ToolBeanRegistrationAotProcessor diff --git a/spring-ai-core/src/test/java/org/springframework/ai/aot/ToolBeanRegistrationAotProcessorTests.java b/spring-ai-core/src/test/java/org/springframework/ai/aot/ToolBeanRegistrationAotProcessorTests.java new file mode 100644 index 00000000000..c5eabe2e352 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/aot/ToolBeanRegistrationAotProcessorTests.java @@ -0,0 +1,89 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.aot; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.tool.annotation.Tool; +import org.springframework.aot.generate.GenerationContext; +import org.springframework.aot.hint.RuntimeHints; +import org.springframework.beans.factory.aot.BeanRegistrationAotContribution; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.beans.factory.support.RegisteredBean; +import org.springframework.beans.factory.support.RootBeanDefinition; +import org.springframework.lang.Nullable; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.springframework.aot.hint.predicate.RuntimeHintsPredicates.reflection; + +/** + * Unit tests for {@link ToolBeanRegistrationAotProcessor}. + * + * @author Thomas Vitale + */ +class ToolBeanRegistrationAotProcessorTests { + + private final GenerationContext generationContext = mock(); + + private final RuntimeHints runtimeHints = new RuntimeHints(); + + @Test + void shouldSkipNonAnnotatedClass() { + process(NonTools.class); + assertThat(this.runtimeHints.reflection().typeHints()).isEmpty(); + } + + @Test + void shouldProcessAnnotatedClass() { + process(TestTools.class); + assertThat(reflection().onType(TestTools.class)).accepts(this.runtimeHints); + } + + private void process(Class beanClass) { + when(generationContext.getRuntimeHints()).thenReturn(runtimeHints); + BeanRegistrationAotContribution contribution = createContribution(beanClass); + if (contribution != null) { + contribution.applyTo(this.generationContext, mock()); + } + } + + private static @Nullable BeanRegistrationAotContribution createContribution(Class beanClass) { + DefaultListableBeanFactory beanFactory = new DefaultListableBeanFactory(); + beanFactory.registerBeanDefinition(beanClass.getName(), new RootBeanDefinition(beanClass)); + return new ToolBeanRegistrationAotProcessor() + .processAheadOfTime(RegisteredBean.of(beanFactory, beanClass.getName())); + } + + static class TestTools { + + @Tool + String testTool() { + return "Testing"; + } + + } + + static class NonTools { + + String nonTool() { + return "More testing"; + } + + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/tool/ToolCallbackTests.java b/spring-ai-core/src/test/java/org/springframework/ai/tool/ToolCallbackTests.java index c823a9525ad..ccc464ca731 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/tool/ToolCallbackTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/tool/ToolCallbackTests.java @@ -24,7 +24,7 @@ static class TestToolCallback implements ToolCallback { private final ToolDefinition toolDefinition; public TestToolCallback(String name) { - this.toolDefinition = ToolDefinition.builder().name(name).description(name).inputTypeSchema("{}").build(); + this.toolDefinition = ToolDefinition.builder().name(name).description(name).inputSchema("{}").build(); } @Override diff --git a/spring-ai-core/src/test/java/org/springframework/ai/tool/definition/DefaultToolDefinitionTests.java b/spring-ai-core/src/test/java/org/springframework/ai/tool/definition/DefaultToolDefinitionTests.java index 3f203e5f13f..141ca4684c2 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/tool/definition/DefaultToolDefinitionTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/tool/definition/DefaultToolDefinitionTests.java @@ -17,7 +17,7 @@ void shouldCreateDefaultToolDefinition() { var toolDefinition = new DefaultToolDefinition("name", "description", "{}"); assertThat(toolDefinition.name()).isEqualTo("name"); assertThat(toolDefinition.description()).isEqualTo("description"); - assertThat(toolDefinition.inputTypeSchema()).isEqualTo("{}"); + assertThat(toolDefinition.inputSchema()).isEqualTo("{}"); } @Test @@ -49,17 +49,17 @@ void shouldThrowExceptionWhenDescriptionIsEmpty() { } @Test - void shouldThrowExceptionWhenInputTypeSchemaIsNull() { + void shouldThrowExceptionWhenInputSchemaIsNull() { assertThatThrownBy(() -> new DefaultToolDefinition("name", "description", null)) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("inputTypeSchema cannot be null or empty"); + .hasMessage("inputSchema cannot be null or empty"); } @Test - void shouldThrowExceptionWhenInputTypeSchemaIsEmpty() { + void shouldThrowExceptionWhenInputSchemaIsEmpty() { assertThatThrownBy(() -> new DefaultToolDefinition("name", "description", "")) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("inputTypeSchema cannot be null or empty"); + .hasMessage("inputSchema cannot be null or empty"); } } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/tool/definition/ToolDefinitionTests.java b/spring-ai-core/src/test/java/org/springframework/ai/tool/definition/ToolDefinitionTests.java index aac3b32fc9c..840882b6567 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/tool/definition/ToolDefinitionTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/tool/definition/ToolDefinitionTests.java @@ -16,14 +16,10 @@ class ToolDefinitionTests { @Test void shouldCreateDefaultToolDefinitionBuilder() { - var toolDefinition = ToolDefinition.builder() - .name("name") - .description("description") - .inputTypeSchema("{}") - .build(); + var toolDefinition = ToolDefinition.builder().name("name").description("description").inputSchema("{}").build(); assertThat(toolDefinition.name()).isEqualTo("name"); assertThat(toolDefinition.description()).isEqualTo("description"); - assertThat(toolDefinition.inputTypeSchema()).isEqualTo("{}"); + assertThat(toolDefinition.inputSchema()).isEqualTo("{}"); } @Test @@ -31,7 +27,7 @@ void shouldCreateToolDefinitionFromMethod() { var toolDefinition = ToolDefinition.from(Tools.class.getDeclaredMethods()[0]); assertThat(toolDefinition.name()).isEqualTo("mySuperTool"); assertThat(toolDefinition.description()).isEqualTo("Test description"); - assertThat(toolDefinition.inputTypeSchema()).isEqualToIgnoringWhitespace(""" + assertThat(toolDefinition.inputSchema()).isEqualToIgnoringWhitespace(""" { "$schema" : "https://json-schema.org/draft/2020-12/schema", "type" : "object", diff --git a/spring-ai-core/src/test/java/org/springframework/ai/tool/execution/DefaultToolCallResultConverterTests.java b/spring-ai-core/src/test/java/org/springframework/ai/tool/execution/DefaultToolCallResultConverterTests.java index 7db22d41c9f..efb1f6c822f 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/tool/execution/DefaultToolCallResultConverterTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/tool/execution/DefaultToolCallResultConverterTests.java @@ -6,7 +6,6 @@ import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; /** * Unit tests for {@link DefaultToolCallResultConverter}. @@ -18,9 +17,9 @@ class DefaultToolCallResultConverterTests { private final DefaultToolCallResultConverter converter = new DefaultToolCallResultConverter(); @Test - void convertWithNullReturnTypeShouldThrowException() { - assertThatThrownBy(() -> converter.apply(null, null)).isInstanceOf(IllegalArgumentException.class) - .hasMessage("returnType cannot be null"); + void convertWithNullReturnTypeShouldReturn() { + String result = converter.apply(null, null); + assertThat(result).isEqualTo("null"); } @Test diff --git a/spring-ai-core/src/test/java/org/springframework/ai/tool/function/FunctionToolCallbackTests.java b/spring-ai-core/src/test/java/org/springframework/ai/tool/function/FunctionToolCallbackTests.java new file mode 100644 index 00000000000..caf3463470f --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/tool/function/FunctionToolCallbackTests.java @@ -0,0 +1,268 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool.function; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.model.ToolContext; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.ToolDefinition; +import org.springframework.ai.tool.execution.ToolCallResultConverter; +import org.springframework.ai.tool.metadata.ToolMetadata; +import org.springframework.ai.util.json.JsonSchemaGenerator; +import org.springframework.core.ParameterizedTypeReference; + +import java.util.List; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Unit tests for {@link FunctionToolCallback}. + * + * @author Thomas Vitale + */ +class FunctionToolCallbackTests { + + @Test + void constructorShouldValidateRequiredParameters() { + ToolDefinition toolDefinition = mock(ToolDefinition.class); + ToolMetadata toolMetadata = mock(ToolMetadata.class); + BiFunction toolFunction = (input, context) -> input; + + assertThatThrownBy(() -> new FunctionToolCallback<>(null, toolMetadata, String.class, toolFunction, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("toolDefinition cannot be null"); + + assertThatThrownBy(() -> new FunctionToolCallback<>(toolDefinition, toolMetadata, null, toolFunction, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("toolInputType cannot be null"); + + assertThatThrownBy(() -> new FunctionToolCallback<>(toolDefinition, toolMetadata, String.class, null, null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("toolFunction cannot be null"); + } + + @Test + void callShouldExecuteToolFunctionAndConvertResult() { + ToolDefinition toolDefinition = mock(ToolDefinition.class); + when(toolDefinition.name()).thenReturn("test-tool"); + BiFunction toolFunction = (input, + context) -> new TestResponse(input.input()); + + ToolCallback callback = FunctionToolCallback.builder("test-tool", toolFunction) + .inputType(TestRequest.class) + .build(); + + String result = callback.call(""" + { + "input": "test input" + } + """, mock(ToolContext.class)); + + assertThat(result).isEqualToIgnoringWhitespace(""" + { + "output": "test input" + } + """); + } + + @Test + void callShouldValidateInput() { + ToolCallback callback = FunctionToolCallback.builder("test-tool", (input, context) -> input) + .inputType(String.class) + .build(); + + assertThatThrownBy(() -> callback.call("")).isInstanceOf(IllegalArgumentException.class) + .hasMessage("toolInput cannot be null or empty"); + + assertThatThrownBy(() -> callback.call(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessage("toolInput cannot be null or empty"); + } + + @Test + void callWithoutContextShouldWorkCorrectly() { + BiFunction toolFunction = (input, + context) -> new TestResponse(input.input()); + + ToolCallback callback = FunctionToolCallback.builder("test-tool", toolFunction) + .inputType(TestRequest.class) + .build(); + + String result = callback.call(""" + { + "input": "test input" + } + """); + + assertThat(result).isEqualToIgnoringWhitespace(""" + { + "output": "test input" + } + """); + } + + // Builder + + @Test + void builderShouldCreateInstanceWithAllProperties() { + ToolMetadata toolMetadata = mock(ToolMetadata.class); + BiFunction toolFunction = (input, context) -> input; + ToolCallResultConverter resultConverter = mock(ToolCallResultConverter.class); + + ToolCallback callback = FunctionToolCallback.builder("testTool", toolFunction) + .description("A test tool") + .inputSchema(JsonSchemaGenerator.generateForType(String.class)) + .inputType(String.class) + .toolMetadata(toolMetadata) + .toolCallResultConverter(resultConverter) + .build(); + + assertThat(callback.getToolDefinition().name()).isEqualTo("testTool"); + assertThat(callback.getToolDefinition().description()).isEqualTo("A test tool"); + assertThat(callback.getToolMetadata()).isEqualTo(toolMetadata); + } + + @Test + void builderShouldCreateInstanceWithCustomSchema() { + ToolMetadata toolMetadata = mock(ToolMetadata.class); + BiFunction toolFunction = (input, context) -> input; + ToolCallResultConverter resultConverter = mock(ToolCallResultConverter.class); + + ToolCallback callback = FunctionToolCallback.builder("testTool", toolFunction) + .description("A test tool") + // Special schema generation required by Vertex AI. + .inputSchema(JsonSchemaGenerator.generateForType(String.class, + JsonSchemaGenerator.SchemaOption.UPPER_CASE_TYPE_VALUES)) + .inputType(String.class) + .toolMetadata(toolMetadata) + .toolCallResultConverter(resultConverter) + .build(); + + assertThat(callback.getToolDefinition().name()).isEqualTo("testTool"); + assertThat(callback.getToolDefinition().description()).isEqualTo("A test tool"); + assertThat(callback.getToolMetadata()).isEqualTo(toolMetadata); + } + + @Test + void whenBuilderWithRequiredPropertiesThenReturn() { + var builder = FunctionToolCallback.builder("test-tool", (input, context) -> input); + assertThat(builder).isNotNull(); + } + + @Test + void whenToolNameIsNullThenThrow() { + assertThatThrownBy(() -> FunctionToolCallback.builder(null, (input, context) -> input)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("name cannot be null or empty"); + } + + @Test + void whenToolNameIsEmptyThenThrow() { + assertThatThrownBy(() -> FunctionToolCallback.builder("", (input, context) -> input)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("name cannot be null or empty"); + } + + @Test + void whenBuildingFromBiFunctionThenReturn() { + var builder = FunctionToolCallback.builder("test-tool", (input, context) -> input); + assertThat(builder).isNotNull(); + } + + @Test + void whenBuildingFromNullBiFunctionThenReturn() { + assertThatThrownBy(() -> FunctionToolCallback.builder("test-tool", (BiFunction) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("toolFunction cannot be null"); + } + + @Test + void whenBuildingFromFunctionThenReturn() { + var builder = FunctionToolCallback.builder("test-tool", (input) -> input); + assertThat(builder).isNotNull(); + } + + @Test + void whenBuildingFromNullFunctionThenReturn() { + assertThatThrownBy(() -> FunctionToolCallback.builder("test-tool", (Function) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("function cannot be null"); + } + + @Test + void whenBuildingFromSupplierThenReturn() { + var builder = FunctionToolCallback.builder("test-tool", () -> "Hello"); + assertThat(builder).isNotNull(); + } + + @Test + void whenBuildingFromNullSupplierThenReturn() { + assertThatThrownBy(() -> FunctionToolCallback.builder("test-tool", (Supplier) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("supplier cannot be null"); + } + + @Test + void whenBuildingFromConsumerThenReturn() { + var builder = FunctionToolCallback.builder("test-tool", (input) -> null); + assertThat(builder).isNotNull(); + } + + @Test + void whenBuildingFromNullConsumerThenReturn() { + assertThatThrownBy(() -> FunctionToolCallback.builder("test-tool", (Consumer) null)) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("consumer cannot be null"); + } + + @Test + void whenInputTypeIsNullThenThrow() { + assertThatThrownBy(() -> FunctionToolCallback.builder("test-tool", (input, context) -> input).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessage("inputType cannot be null"); + } + + @Test + void whenToolDescriptionIsNullThenComputeFromName() { + ToolCallback callback = FunctionToolCallback.builder("mySuperTestTool", (input, context) -> input) + .inputType(String.class) + .build(); + assertThat(callback.getToolDefinition().description()).isEqualTo("my super test tool"); + } + + @Test + void whenInputTypeIsGenericThenReturn() { + ToolCallback callback = FunctionToolCallback.builder("mySuperTestTool", (input, context) -> input) + .inputType(new ParameterizedTypeReference>() { + }) + .build(); + assertThat(callback).isNotNull(); + } + + public record TestRequest(String input) { + } + + public record TestResponse(String output) { + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/tool/utils/ToolUtilsTests.java b/spring-ai-core/src/test/java/org/springframework/ai/tool/utils/ToolUtilsTests.java index 1a35c8799e4..90295befc2d 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/tool/utils/ToolUtilsTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/tool/utils/ToolUtilsTests.java @@ -9,6 +9,7 @@ import org.springframework.ai.tool.util.ToolUtils; import java.lang.reflect.Method; +import java.lang.reflect.Type; import java.util.List; import static org.assertj.core.api.Assertions.assertThat; @@ -68,6 +69,12 @@ void shouldGetToolDescriptionFromAnnotation() throws Exception { assertThat(ToolUtils.getToolDescription(method)).isEqualTo("Custom description"); } + @Test + void shouldGetToolDescriptionFromName() { + String description = ToolUtils.getToolDescriptionFromName("mySuperSpecialTool"); + assertThat(description).isEqualTo("my super special tool"); + } + @Test void shouldGetMethodNameWhenNoCustomDescriptionInAnnotation() throws Exception { Method method = TestTools.class.getMethod("toolWithoutCustomDescription"); @@ -119,7 +126,7 @@ static class TestToolCallback implements ToolCallback { private final ToolDefinition toolDefinition; public TestToolCallback(String name) { - this.toolDefinition = ToolDefinition.builder().name(name).description(name).inputTypeSchema("{}").build(); + this.toolDefinition = ToolDefinition.builder().name(name).description(name).inputSchema("{}").build(); } @Override @@ -175,8 +182,8 @@ public void camelCaseMethodWithoutAnnotation() { public static class CustomToolCallResultConverter implements ToolCallResultConverter { @Override - public String apply(Object result, Class returnType) { - return returnType.getName(); + public String apply(Object result, Type returnType) { + return returnType == null ? "null" : returnType.getTypeName(); } } @@ -188,8 +195,8 @@ private InvalidToolCallResultConverter() { } @Override - public String apply(Object result, Class returnType) { - return returnType.getName(); + public String apply(Object result, Type returnType) { + return returnType == null ? "null" : returnType.getTypeName(); } } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/util/json/JsonParserTests.java b/spring-ai-core/src/test/java/org/springframework/ai/util/json/JsonParserTests.java index e3552f69407..a1879a4e9c5 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/util/json/JsonParserTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/util/json/JsonParserTests.java @@ -3,6 +3,8 @@ import com.fasterxml.jackson.core.type.TypeReference; import org.junit.jupiter.api.Test; +import java.lang.reflect.Type; + import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -108,6 +110,20 @@ void fromJsonToObjectWithUnknownProperty() { assertThat(object.name).isEqualTo("James"); } + @Test + void fromJsonToObjectWithType() { + var json = """ + { + "name" : "John", + "age" : 30 + } + """; + TestRecord object = JsonParser.fromJson(json, (Type) TestRecord.class); + assertThat(object).isNotNull(); + assertThat(object.name).isEqualTo("John"); + assertThat(object.age).isEqualTo(30); + } + @Test void fromObjectToJson() { var object = new TestRecord("John", 30); diff --git a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/FunctionToolCallbackTests.java b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/FunctionToolCallbackTests.java new file mode 100644 index 00000000000..27ee2e0612f --- /dev/null +++ b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/tool/FunctionToolCallbackTests.java @@ -0,0 +1,275 @@ +/* + * Copyright 2023-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.integration.tests.tool; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.integration.tests.TestApplication; +import org.springframework.ai.openai.OpenAiChatModel; +import org.springframework.ai.tool.function.FunctionToolCallback; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.context.annotation.Description; +import org.springframework.context.annotation.Import; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Consumer; +import java.util.function.Function; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link FunctionToolCallback}. + * + * @author Thomas Vitale + */ +@SpringBootTest(classes = TestApplication.class) +@Import(FunctionToolCallbackTests.Tools.class) +@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") +public class FunctionToolCallbackTests { + + // @formatter:off + + private static final Logger logger = LoggerFactory.getLogger(FunctionToolCallbackTests.class); + + @Autowired + OpenAiChatModel openAiChatModel; + + @Test + void chatVoidInputFromBean() { + var content = ChatClient.builder(this.openAiChatModel) + .build() + .prompt() + .user("Welcome the users to the library") + .tools(Tools.WELCOME) + .call() + .content(); + assertThat(content).isNotEmpty(); + } + + @Test + void chatVoidInputFromCallback() { + var content = ChatClient.builder(this.openAiChatModel) + .build() + .prompt() + .user("Welcome the users to the library") + .toolCallbacks(FunctionToolCallback.builder("sayWelcome", (input) -> { + logger.info("CALLBACK - Welcoming users to the library"); + }) + .description("Welcome users to the library") + .inputType(Void.class) + .build()) + .call() + .content(); + assertThat(content).isNotEmpty(); + } + + @Test + void chatVoidOutputFromBean() { + var content = ChatClient.builder(this.openAiChatModel) + .build() + .prompt() + .user("Welcome %s to the library".formatted("James Bond")) + .tools(Tools.WELCOME_USER) + .call() + .content(); + assertThat(content).isNotEmpty(); + } + + @Test + void chatVoidOutputFromCallback() { + var content = ChatClient.builder(this.openAiChatModel) + .build() + .prompt() + .user("Welcome %s to the library".formatted("James Bond")) + .toolCallbacks(FunctionToolCallback.builder("welcomeUser", (user) -> { + logger.info("CALLBACK - Welcoming {} to the library", ((User) user).name()); + }) + .description("Welcome a specific user to the library") + .inputType(User.class) + .build()) + .call() + .content(); + assertThat(content).contains("Bond"); + } + + @Test + void chatSingleFromBean() { + var content = ChatClient.builder(this.openAiChatModel) + .build() + .prompt() + .user("What books written by %s are available in the library?".formatted("J.R.R. Tolkien")) + .tools(Tools.BOOKS_BY_AUTHOR) + .call() + .content(); + assertThat(content).isNotEmpty() + .contains("The Hobbit") + .contains("The Lord of The Rings") + .contains("The Silmarillion"); + } + + @Test + void chatSingleFromCallback() { + Function> function = author -> { + logger.info("CALLBACK - Getting books by author: {}", author.name()); + return new BookService().getBooksByAuthor(author); + }; + var content = ChatClient.builder(this.openAiChatModel) + .build() + .prompt() + .user("What books written by %s are available in the library?".formatted("J.R.R. Tolkien")) + .toolCallbacks(FunctionToolCallback.builder("availableBooksByAuthor", function) + .description("Get the list of books written by the given author available in the library") + .inputType(Author.class) + .build()) + .call() + .content(); + assertThat(content).isNotEmpty() + .contains("The Hobbit") + .contains("The Lord of The Rings") + .contains("The Silmarillion"); + } + + @Test + void chatListFromBean() { + var content = ChatClient.builder(this.openAiChatModel) + .build() + .prompt() + .user("What authors wrote the books %s and %s available in the library?".formatted("The Hobbit", "Narnia")) + .tools(Tools.AUTHORS_BY_BOOKS) + .call() + .content(); + assertThat(content).isNotEmpty().contains("J.R.R. Tolkien").contains("C.S. Lewis"); + } + + @Test + void chatListFromCallback() { + Function> function = books -> { + logger.info("CALLBACK - Getting authors by books: {}", books.books().stream().map(Book::title).toList()); + return new BookService().getAuthorsByBook(books.books()); + }; + var content = ChatClient.builder(this.openAiChatModel) + .build() + .prompt() + .user("What authors wrote the books %s and %s available in the library?".formatted("The Hobbit", "Narnia")) + .toolCallbacks(FunctionToolCallback.builder("authorsByAvailableBooks", function) + .description("Get the list of authors who wrote the given books available in the library") + .inputType(Books.class) + .build()) + .call() + .content(); + assertThat(content).isNotEmpty().contains("J.R.R. Tolkien").contains("C.S. Lewis"); + } + + @Configuration(proxyBeanMethods = false) + static class Tools { + + public static final String AUTHORS_BY_BOOKS = "authorsByBooks"; + + public static final String BOOKS_BY_AUTHOR = "booksByAuthor"; + + public static final String WELCOME = "welcome"; + + public static final String WELCOME_USER = "welcomeUser"; + + private static final Logger logger = LoggerFactory.getLogger(Tools.class); + + private final BookService bookService = new BookService(); + + @Bean(WELCOME) + @Description("Welcome users to the library") + Consumer welcome() { + return (input) -> logger.info("Welcoming users to the library"); + } + + @Bean(WELCOME_USER) + @Description("Welcome a specific user to the library") + Consumer welcomeUser() { + return user -> logger.info("Welcoming {} to the library", user.name()); + } + + @Bean(BOOKS_BY_AUTHOR) + @Description("Get the list of books written by the given author available in the library") + Function> booksByAuthor() { + return author -> { + logger.info("Getting books by author: {}", author.name()); + return bookService.getBooksByAuthor(author); + }; + } + + @Bean(AUTHORS_BY_BOOKS) + @Description("Get the list of authors who wrote the given books available in the library") + Function> authorsByBooks() { + return books -> { + logger.info("Getting authors by books: {}", books.books().stream().map(Book::title).toList()); + return bookService.getAuthorsByBook(books.books()); + }; + } + + } + + public record User(String name) { + } + + public record Author(String name) { + } + + public record Authors(List authors) { + } + + public record Book(String title, String author) { + } + + public record Books(List books) { + } + + static class BookService { + + private static final Map books = new ConcurrentHashMap<>(); + + static { + books.put(1, new Book("His Dark Materials", "Philip Pullman")); + books.put(2, new Book("Narnia", "C.S. Lewis")); + books.put(3, new Book("The Hobbit", "J.R.R. Tolkien")); + books.put(4, new Book("The Lord of The Rings", "J.R.R. Tolkien")); + books.put(5, new Book("The Silmarillion", "J.R.R. Tolkien")); + } + + public List getBooksByAuthor(Author author) { + return books.values().stream().filter(book -> author.name().equals(book.author())).toList(); + } + + public List getAuthorsByBook(List booksToSearch) { + return books.values() + .stream() + .filter(book -> booksToSearch.stream().anyMatch(b -> b.title().equals(book.title()))) + .map(book -> new Author(book.author())) + .toList(); + } + + } + + // @formatter:on + +}