From 82d98b6b485a8d98caf7c9dd17ba0b9f2ab011bd Mon Sep 17 00:00:00 2001 From: Christian Tzolov Date: Thu, 27 Feb 2025 11:01:55 +0100 Subject: [PATCH] fix: reverse ToolContext validation logic in MethodToolCallback Changes the validation logic in MethodToolCallback to check if a ToolContext is required by the method but not provided, rather than checking if a ToolContext is provided but not supported by the method. This ensures methods that expect a ToolContext parameter receive one. Updates tests cases to reflect the new validation logic Resolves #2337 Signed-off-by: Christian Tzolov --- ...picChatClientMethodInvokingFunctionCallbackIT.java | 7 +++---- ...nAiChatClientMethodInvokingFunctionCallbackIT.java | 11 +++++------ .../ai/tool/method/MethodToolCallback.java | 6 +++--- .../ai/tool/method/MethodToolCallbackTests.java | 10 ++++------ 4 files changed, 15 insertions(+), 19 deletions(-) diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientMethodInvokingFunctionCallbackIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientMethodInvokingFunctionCallbackIT.java index 8828c7ffe9b..b30d84017be 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientMethodInvokingFunctionCallbackIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/client/AnthropicChatClientMethodInvokingFunctionCallbackIT.java @@ -192,13 +192,13 @@ void methodGetWeatherToolContext() { } @Test - void methodGetWeatherToolContextButNonContextMethod() { + void methodGetWeatherWithContextMethodButMissingContext() { TestFunctionClass targetObject = new TestFunctionClass(); // @formatter:off var toolMethod = ReflectionUtils.findMethod( - TestFunctionClass.class, "getWeatherNonStatic", String.class, Unit.class); + TestFunctionClass.class, "getWeatherWithContext", String.class, Unit.class, ToolContext.class); assertThatThrownBy(() -> ChatClient.create(this.chatModel).prompt() .user("What's the weather like in San Francisco, Tokyo, and Paris? Use Celsius.") @@ -209,11 +209,10 @@ void methodGetWeatherToolContextButNonContextMethod() { .toolMethod(toolMethod) .toolObject(targetObject) .build()) - .toolContext(Map.of("tool", "value")) .call() .content()) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("ToolContext is not supported by the method as an argument"); + .hasMessage("ToolContext is required by the method as an argument"); // @formatter:on } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMethodInvokingFunctionCallbackIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMethodInvokingFunctionCallbackIT.java index 6aba6982011..5ead8676a8f 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMethodInvokingFunctionCallbackIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientMethodInvokingFunctionCallbackIT.java @@ -165,12 +165,12 @@ void methodGetWeatherToolContext() { } @Test - void methodGetWeatherToolContextButNonContextMethod() { + void methodGetWeatherToolContextButMissingContextArgument() { TestFunctionClass targetObject = new TestFunctionClass(); - var toolMethod = ReflectionUtils.findMethod(TestFunctionClass.class, "getWeatherNonStatic", String.class, - Unit.class); + var toolMethod = ReflectionUtils.findMethod(TestFunctionClass.class, "getWeatherWithContext", String.class, + Unit.class, ToolContext.class); // @formatter:off assertThatThrownBy(() -> ChatClient.create(this.chatModel).prompt() @@ -181,12 +181,11 @@ void methodGetWeatherToolContextButNonContextMethod() { .build()) .toolMethod(toolMethod) .toolObject(targetObject) - .build()) - .toolContext(Map.of("tool", "value")) + .build()) .call() .content()) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("ToolContext is not supported by the method as an argument"); + .hasMessage("ToolContext is required by the method as an argument"); // @formatter:on } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java b/spring-ai-core/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java index cd740a26905..e5764664781 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/tool/method/MethodToolCallback.java @@ -115,11 +115,11 @@ public String call(String toolInput, @Nullable ToolContext toolContext) { } private void validateToolContextSupport(@Nullable ToolContext toolContext) { - var isToolContextRequired = toolContext != null && !CollectionUtils.isEmpty(toolContext.getContext()); + var isNonEmptyToolContextProvided = toolContext != null && !CollectionUtils.isEmpty(toolContext.getContext()); var isToolContextAcceptedByMethod = Stream.of(toolMethod.getParameterTypes()) .anyMatch(type -> ClassUtils.isAssignable(type, ToolContext.class)); - if (isToolContextRequired && !isToolContextAcceptedByMethod) { - throw new IllegalArgumentException("ToolContext is not supported by the method as an argument"); + if (isToolContextAcceptedByMethod && !isNonEmptyToolContextProvided) { + throw new IllegalArgumentException("ToolContext is required by the method as an argument"); } } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackTests.java b/spring-ai-core/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackTests.java index a21810c3d2a..cab199bbbb0 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/tool/method/MethodToolCallbackTests.java @@ -69,8 +69,8 @@ void shouldHandleToolContextWhenSupported() { } @Test - void shouldThrowExceptionWhenToolContextNotSupported() { - Method toolMethod = getMethod("publicMethod", PublicTools.class); + void shouldThrowExceptionWhenToolContextArgumentIsMissing() { + Method toolMethod = getMethod("methodWithToolContext", ToolContextTools.class); MethodToolCallback callback = MethodToolCallback.builder() .toolDefinition(ToolDefinition.from(toolMethod)) .toolMetadata(ToolMetadata.from(toolMethod)) @@ -78,14 +78,12 @@ void shouldThrowExceptionWhenToolContextNotSupported() { .toolObject(new PublicTools()) .build(); - ToolContext toolContext = new ToolContext(Map.of("key", "value")); - assertThatThrownBy(() -> callback.call(""" { "input": "test" } - """, toolContext)).isInstanceOf(IllegalArgumentException.class) - .hasMessageContaining("ToolContext is not supported"); + """)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("ToolContext is required by the method as an argument"); } @Test