diff --git a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java index 2719569dcce..0bbc99be341 100644 --- a/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java +++ b/models/spring-ai-bedrock-converse/src/test/java/org/springframework/ai/bedrock/converse/BedrockConverseChatClientIT.java @@ -222,6 +222,23 @@ void functionCallTest() { assertThat(response).contains("30", "10", "15"); } + @Test + void functionCallWithAdvisorTest() { + + // @formatter:off + String response = ChatClient.create(this.chatModel) + .prompt("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.") + .function("getCurrentWeather", "Get the weather in location", new MockWeatherService()) + .advisors(new SimpleLoggerAdvisor()) + .call() + .content(); + // @formatter:on + + logger.info("Response: {}", response); + + assertThat(response).contains("30", "10", "15"); + } + @Test void defaultFunctionCallTest() { diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientProxyFunctionCallsIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientProxyFunctionCallsIT.java new file mode 100644 index 00000000000..709398dee78 --- /dev/null +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/client/OpenAiChatClientProxyFunctionCallsIT.java @@ -0,0 +1,185 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.chat.client; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.JsonMappingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.messages.AssistantMessage; +import org.springframework.ai.chat.messages.Message; +import org.springframework.ai.chat.messages.ToolResponseMessage; +import org.springframework.ai.chat.messages.UserMessage; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.chat.model.Generation; +import org.springframework.ai.model.ModelOptionsUtils; +import org.springframework.ai.model.function.FunctionCallback; +import org.springframework.ai.model.function.ToolCallHelper; +import org.springframework.ai.openai.OpenAiChatModel; +import org.springframework.ai.openai.OpenAiChatOptions; +import org.springframework.ai.openai.OpenAiTestConfiguration; +import org.springframework.ai.openai.api.OpenAiApi; +import org.springframework.ai.openai.testutils.AbstractIT; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.core.io.Resource; +import org.springframework.test.context.ActiveProfiles; +import org.springframework.util.CollectionUtils; + +import static org.assertj.core.api.Assertions.assertThat; + +@SpringBootTest(classes = OpenAiTestConfiguration.class) +@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+") +@ActiveProfiles("logging-test") +class OpenAiChatClientProxyFunctionCallsIT extends AbstractIT { + + private static final Logger logger = LoggerFactory.getLogger(OpenAiChatClientMultipleFunctionCallsIT.class); + + @Value("classpath:/prompts/system-message.st") + private Resource systemTextResource; + + FunctionCallback functionDefinition = new ToolCallHelper.FunctionDefinition("getWeatherInLocation", + "Get the weather in location", """ + { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["C", "F"] + } + }, + "required": ["location", "unit"] + } + """); + + @Autowired + private OpenAiChatModel chatModel; + + // Helper class that reuses some of the {@link AbstractToolCallSupport} functionality + // to help to implement the function call handling logic on the client side. + private ToolCallHelper toolCallHelper = new ToolCallHelper(); + + // Function which will be called by the AI model. + private String getWeatherInLocation(String location, String unit) { + + double temperature = 0; + + if (location.contains("Paris")) { + temperature = 15; + } + else if (location.contains("Tokyo")) { + temperature = 10; + } + else if (location.contains("San Francisco")) { + temperature = 30; + } + + return String.format("The weather in %s is %s%s", location, temperature, unit); + } + + @Test + void toolProxyFunctionCall() throws JsonMappingException, JsonProcessingException { + + List messages = List + .of(new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?")); + + boolean isToolCall = false; + + ChatResponse chatResponse = null; + + var chatClient = ChatClient.builder(this.chatModel).build(); + + do { + + chatResponse = chatClient.prompt() + .messages(messages) + .functions(this.functionDefinition) + .options(OpenAiChatOptions.builder().withProxyToolCalls(true).build()) + .call() + .chatResponse(); + + // Note that the tool call check could be platform specific because the finish + // reasons. + isToolCall = this.toolCallHelper.isToolCall(chatResponse, + Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(), + OpenAiApi.ChatCompletionFinishReason.STOP.name())); + + if (isToolCall) { + + Optional toolCallGeneration = chatResponse.getResults() + .stream() + .filter(g -> !CollectionUtils.isEmpty(g.getOutput().getToolCalls())) + .findFirst(); + + assertThat(toolCallGeneration).isNotEmpty(); + + AssistantMessage assistantMessage = toolCallGeneration.get().getOutput(); + + List toolResponses = new ArrayList<>(); + + for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) { + + var functionName = toolCall.name(); + + assertThat(functionName).isEqualTo("getWeatherInLocation"); + + String functionArguments = toolCall.arguments(); + + @SuppressWarnings("unchecked") + Map argumentsMap = new ObjectMapper().readValue(functionArguments, Map.class); + + String functionResponse = getWeatherInLocation(argumentsMap.get("location").toString(), + argumentsMap.get("unit").toString()); + + toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), functionName, + ModelOptionsUtils.toJsonString(functionResponse))); + } + + ToolResponseMessage toolMessageResponse = new ToolResponseMessage(toolResponses, Map.of()); + + messages = this.toolCallHelper.buildToolCallConversation(messages, assistantMessage, + toolMessageResponse); + + assertThat(messages).isNotEmpty(); + + // prompt = new Prompt(toolCallConversation, prompt.getOptions()); + } + } + while (isToolCall); + + logger.info("Response: {}", chatResponse); + + assertThat(chatResponse.getResult().getOutput().getContent()).contains("30", "10", "15"); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java index 7bc48c7723e..e79f84ccc38 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequest.java @@ -84,7 +84,8 @@ public record AdvisedRequest( public AdvisedRequest { Assert.notNull(chatModel, "chatModel cannot be null"); - Assert.hasText(userText, "userText cannot be null or empty"); + Assert.isTrue(StringUtils.hasText(userText) || !CollectionUtils.isEmpty(messages), + "userText cannot be null or empty unless messages are provided and contain Tool Response message."); Assert.notNull(media, "media cannot be null"); Assert.noNullElements(media, "media cannot contain null elements"); Assert.notNull(functionNames, "functionNames cannot be null"); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequestTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequestTests.java index 8d392814f4f..4f2d4415ec2 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequestTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/api/AdvisedRequestTests.java @@ -54,7 +54,8 @@ void whenUserTextIsNullThenThrows() { assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), null, null, null, List.of(), List.of(), List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of())) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("userText cannot be null or empty"); + .hasMessage( + "userText cannot be null or empty unless messages are provided and contain Tool Response message."); } @Test @@ -62,7 +63,8 @@ void whenUserTextIsEmptyThenThrows() { assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "", null, null, List.of(), List.of(), List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of())) .isInstanceOf(IllegalArgumentException.class) - .hasMessage("userText cannot be null or empty"); + .hasMessage( + "userText cannot be null or empty unless messages are provided and contain Tool Response message."); } @Test diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/converse/BedrockConverseProxyChatProperties.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/converse/BedrockConverseProxyChatProperties.java index 4cb3182090b..88f5068025a 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/converse/BedrockConverseProxyChatProperties.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/bedrock/converse/BedrockConverseProxyChatProperties.java @@ -16,7 +16,6 @@ package org.springframework.ai.autoconfigure.bedrock.converse; -import org.springframework.ai.bedrock.converse.BedrockProxyChatModel; import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions; import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.NestedConfigurationProperty; @@ -38,12 +37,6 @@ public class BedrockConverseProxyChatProperties { */ private boolean enabled = true; - /** - * The generative id to use. See the {@link BedrockProxyChatModel} for the supported - * models. - */ - private String model = "anthropic.claude-3-5-sonnet-20240620-v1:0"; - @NestedConfigurationProperty private PortableFunctionCallingOptions options = PortableFunctionCallingOptions.builder() .withTemperature(0.7) @@ -59,14 +52,6 @@ public void setEnabled(boolean enabled) { this.enabled = enabled; } - public String getModel() { - return this.model; - } - - public void setModel(String model) { - this.model = model; - } - public PortableFunctionCallingOptions getOptions() { return this.options; } diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/converse/BedrockConverseProxyChatPropertiesTests.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/converse/BedrockConverseProxyChatPropertiesTests.java index 09cb19c4b7b..4b4bb809a44 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/converse/BedrockConverseProxyChatPropertiesTests.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/bedrock/converse/BedrockConverseProxyChatPropertiesTests.java @@ -71,12 +71,12 @@ public void chatCompletionDisabled() { .run(context -> assertThat(context.getBeansOfType(BedrockConverseProxyChatProperties.class)).isNotEmpty()); // Explicitly enable the chat auto-configuration. - new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.converse.chat..enabled=true") + new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.converse.chat.enabled=true") .withConfiguration(AutoConfigurations.of(BedrockConverseProxyChatAutoConfiguration.class)) .run(context -> assertThat(context.getBeansOfType(BedrockConverseProxyChatProperties.class)).isNotEmpty()); // Explicitly disable the chat auto-configuration. - new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.converse.chat..enabled=false") + new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.converse.chat.enabled=false") .withConfiguration(AutoConfigurations.of(BedrockConverseProxyChatAutoConfiguration.class)) .run(context -> assertThat(context.getBeansOfType(BedrockConverseProxyChatProperties.class)).isEmpty()); }