From 311f0d0d385815837a5764dd60a493aac0fdfa52 Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Thu, 1 May 2025 19:18:06 +0300 Subject: [PATCH] feat: Improve validation in MethodToolCallbackProvider This ensures that validation errors are caught early during object construction rather than later when methods are called, providing better error feedback. - Add validation for tool-annotated methods during construction - Validate duplicate tool names in constructor instead of only at getToolCallbacks() time - Add comprehensive test suite for MethodToolCallbackProvider Signed-off-by: Christian Tzolov --- .../method/MethodToolCallbackProvider.java | 20 +++ .../MethodToolCallbackProviderTests.java | 140 ++++++++++++++++++ 2 files changed, 160 insertions(+) create mode 100644 spring-ai-model/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackProviderTests.java diff --git a/spring-ai-model/src/main/java/org/springframework/ai/tool/method/MethodToolCallbackProvider.java b/spring-ai-model/src/main/java/org/springframework/ai/tool/method/MethodToolCallbackProvider.java index 91b8eb5cc6d..3178546c373 100644 --- a/spring-ai-model/src/main/java/org/springframework/ai/tool/method/MethodToolCallbackProvider.java +++ b/spring-ai-model/src/main/java/org/springframework/ai/tool/method/MethodToolCallbackProvider.java @@ -19,6 +19,7 @@ import java.lang.reflect.Method; import java.util.Arrays; import java.util.List; +import java.util.Optional; import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Supplier; @@ -44,6 +45,7 @@ * {@link Tool}-annotated methods. * * @author Thomas Vitale + * @author Christian Tzolov * @since 1.0.0 */ public final class MethodToolCallbackProvider implements ToolCallbackProvider { @@ -55,7 +57,25 @@ public final class MethodToolCallbackProvider implements ToolCallbackProvider { private MethodToolCallbackProvider(List toolObjects) { Assert.notNull(toolObjects, "toolObjects cannot be null"); Assert.noNullElements(toolObjects, "toolObjects cannot contain null elements"); + assertToolAnnotatedMethodsPresent(toolObjects); this.toolObjects = toolObjects; + validateToolCallbacks(getToolCallbacks()); + } + + private void assertToolAnnotatedMethodsPresent(List toolObjects) { + + for (Object toolObject : toolObjects) { + List toolMethods = Stream + .of(ReflectionUtils.getDeclaredMethods( + AopUtils.isAopProxy(toolObject) ? AopUtils.getTargetClass(toolObject) : toolObject.getClass())) + .filter(toolMethod -> toolMethod.isAnnotationPresent(Tool.class)) + .filter(toolMethod -> !isFunctionalType(toolMethod)) + .toList(); + + if (toolMethods.isEmpty()) { + throw new IllegalStateException("No @Tool annotated methods found in " + toolObject); + } + } } @Override diff --git a/spring-ai-model/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackProviderTests.java b/spring-ai-model/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackProviderTests.java new file mode 100644 index 00000000000..cdc9fccedfb --- /dev/null +++ b/spring-ai-model/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackProviderTests.java @@ -0,0 +1,140 @@ +/* + * Copyright 2025-2025 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.tool.method; + +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.function.Supplier; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.tool.annotation.Tool; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** + * Unit tests for {@link MethodToolCallbackProvider}. + * + * @author Christian Tzolov + */ +class MethodToolCallbackProviderTests { + + @Test + void whenToolObjectHasToolAnnotatedMethodThenSucceed() { + MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder() + .toolObjects(new ValidToolObject()) + .build(); + + assertThat(provider.getToolCallbacks()).hasSize(1); + assertThat(provider.getToolCallbacks()[0].getToolDefinition().name()).isEqualTo("validTool"); + } + + @Test + void whenToolObjectHasNoToolAnnotatedMethodThenThrow() { + assertThatThrownBy( + () -> MethodToolCallbackProvider.builder().toolObjects(new NoToolAnnotatedMethodObject()).build()) + .isInstanceOf(IllegalStateException.class) + .hasMessageContaining("No @Tool annotated methods found in"); + } + + @Test + void whenToolObjectHasOnlyFunctionalTypeToolMethodsThenThrow() { + assertThatThrownBy(() -> MethodToolCallbackProvider.builder() + .toolObjects(new OnlyFunctionalTypeToolMethodsObject()) + .build()).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("No @Tool annotated methods found in"); + } + + @Test + void whenToolObjectHasMixOfValidAndFunctionalTypeToolMethodsThenSucceed() { + MethodToolCallbackProvider provider = MethodToolCallbackProvider.builder() + .toolObjects(new MixedToolMethodsObject()) + .build(); + + assertThat(provider.getToolCallbacks()).hasSize(1); + assertThat(provider.getToolCallbacks()[0].getToolDefinition().name()).isEqualTo("validTool"); + } + + @Test + void whenMultipleToolObjectsWithSameToolNameThenThrow() { + assertThatThrownBy(() -> MethodToolCallbackProvider.builder() + .toolObjects(new ValidToolObject(), new DuplicateToolNameObject()) + .build()).isInstanceOf(IllegalStateException.class) + .hasMessageContaining("Multiple tools with the same name (validTool) found in sources"); + } + + static class ValidToolObject { + + @Tool + public String validTool() { + return "Valid tool result"; + } + + } + + static class NoToolAnnotatedMethodObject { + + public String notATool() { + return "Not a tool"; + } + + } + + static class OnlyFunctionalTypeToolMethodsObject { + + @Tool + public Function functionTool() { + return input -> "Function result: " + input; + } + + @Tool + public Supplier supplierTool() { + return () -> "Supplier result"; + } + + @Tool + public Consumer consumerTool() { + return input -> System.out.println("Consumer received: " + input); + } + + } + + static class MixedToolMethodsObject { + + @Tool + public String validTool() { + return "Valid tool result"; + } + + @Tool + public Function functionTool() { + return input -> "Function result: " + input; + } + + } + + static class DuplicateToolNameObject { + + @Tool + public String validTool() { + return "Duplicate tool result"; + } + + } + +}