Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,23 @@ void functionCallTest() {
assertThat(response).contains("30", "10", "15");
}

@Test
void functionCallWithAdvisorTest() {

// @formatter:off
String response = ChatClient.create(this.chatModel)
.prompt("What's the weather like in San Francisco, Tokyo, and Paris? Return the temperature in Celsius.")
.function("getCurrentWeather", "Get the weather in location", new MockWeatherService())
.advisors(new SimpleLoggerAdvisor())
.call()
.content();
// @formatter:on

logger.info("Response: {}", response);

assertThat(response).contains("30", "10", "15");
}

@Test
void defaultFunctionCallTest() {

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.openai.chat.client;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.JsonMappingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.ToolResponseMessage;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.model.ChatResponse;
import org.springframework.ai.chat.model.Generation;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.ToolCallHelper;
import org.springframework.ai.openai.OpenAiChatModel;
import org.springframework.ai.openai.OpenAiChatOptions;
import org.springframework.ai.openai.OpenAiTestConfiguration;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.testutils.AbstractIT;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.core.io.Resource;
import org.springframework.test.context.ActiveProfiles;
import org.springframework.util.CollectionUtils;

import static org.assertj.core.api.Assertions.assertThat;

@SpringBootTest(classes = OpenAiTestConfiguration.class)
@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".+")
@ActiveProfiles("logging-test")
class OpenAiChatClientProxyFunctionCallsIT extends AbstractIT {

private static final Logger logger = LoggerFactory.getLogger(OpenAiChatClientMultipleFunctionCallsIT.class);

@Value("classpath:/prompts/system-message.st")
private Resource systemTextResource;

FunctionCallback functionDefinition = new ToolCallHelper.FunctionDefinition("getWeatherInLocation",
"Get the weather in location", """
{
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "The city and state e.g. San Francisco, CA"
},
"unit": {
"type": "string",
"enum": ["C", "F"]
}
},
"required": ["location", "unit"]
}
""");

@Autowired
private OpenAiChatModel chatModel;

// Helper class that reuses some of the {@link AbstractToolCallSupport} functionality
// to help to implement the function call handling logic on the client side.
private ToolCallHelper toolCallHelper = new ToolCallHelper();

// Function which will be called by the AI model.
private String getWeatherInLocation(String location, String unit) {

double temperature = 0;

if (location.contains("Paris")) {
temperature = 15;
}
else if (location.contains("Tokyo")) {
temperature = 10;
}
else if (location.contains("San Francisco")) {
temperature = 30;
}

return String.format("The weather in %s is %s%s", location, temperature, unit);
}

@Test
void toolProxyFunctionCall() throws JsonMappingException, JsonProcessingException {

List<Message> messages = List
.of(new UserMessage("What's the weather like in San Francisco, Tokyo, and Paris?"));

boolean isToolCall = false;

ChatResponse chatResponse = null;

var chatClient = ChatClient.builder(this.chatModel).build();

do {

chatResponse = chatClient.prompt()
.messages(messages)
.functions(this.functionDefinition)
.options(OpenAiChatOptions.builder().withProxyToolCalls(true).build())
.call()
.chatResponse();

// Note that the tool call check could be platform specific because the finish
// reasons.
isToolCall = this.toolCallHelper.isToolCall(chatResponse,
Set.of(OpenAiApi.ChatCompletionFinishReason.TOOL_CALLS.name(),
OpenAiApi.ChatCompletionFinishReason.STOP.name()));

if (isToolCall) {

Optional<Generation> toolCallGeneration = chatResponse.getResults()
.stream()
.filter(g -> !CollectionUtils.isEmpty(g.getOutput().getToolCalls()))
.findFirst();

assertThat(toolCallGeneration).isNotEmpty();

AssistantMessage assistantMessage = toolCallGeneration.get().getOutput();

List<ToolResponseMessage.ToolResponse> toolResponses = new ArrayList<>();

for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {

var functionName = toolCall.name();

assertThat(functionName).isEqualTo("getWeatherInLocation");

String functionArguments = toolCall.arguments();

@SuppressWarnings("unchecked")
Map<String, String> argumentsMap = new ObjectMapper().readValue(functionArguments, Map.class);

String functionResponse = getWeatherInLocation(argumentsMap.get("location").toString(),
argumentsMap.get("unit").toString());

toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), functionName,
ModelOptionsUtils.toJsonString(functionResponse)));
}

ToolResponseMessage toolMessageResponse = new ToolResponseMessage(toolResponses, Map.of());

messages = this.toolCallHelper.buildToolCallConversation(messages, assistantMessage,
toolMessageResponse);

assertThat(messages).isNotEmpty();

// prompt = new Prompt(toolCallConversation, prompt.getOptions());
}
}
while (isToolCall);

logger.info("Response: {}", chatResponse);

assertThat(chatResponse.getResult().getOutput().getContent()).contains("30", "10", "15");
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ public record AdvisedRequest(

public AdvisedRequest {
Assert.notNull(chatModel, "chatModel cannot be null");
Assert.hasText(userText, "userText cannot be null or empty");
Assert.isTrue(StringUtils.hasText(userText) || !CollectionUtils.isEmpty(messages),
"userText cannot be null or empty unless messages are provided and contain Tool Response message.");
Assert.notNull(media, "media cannot be null");
Assert.noNullElements(media, "media cannot contain null elements");
Assert.notNull(functionNames, "functionNames cannot be null");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,17 @@ void whenUserTextIsNullThenThrows() {
assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), null, null, null, List.of(), List.of(),
List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of()))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("userText cannot be null or empty");
.hasMessage(
"userText cannot be null or empty unless messages are provided and contain Tool Response message.");
}

@Test
void whenUserTextIsEmptyThenThrows() {
assertThatThrownBy(() -> new AdvisedRequest(mock(ChatModel.class), "", null, null, List.of(), List.of(),
List.of(), List.of(), Map.of(), Map.of(), List.of(), Map.of(), Map.of(), Map.of()))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage("userText cannot be null or empty");
.hasMessage(
"userText cannot be null or empty unless messages are provided and contain Tool Response message.");
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@

package org.springframework.ai.autoconfigure.bedrock.converse;

import org.springframework.ai.bedrock.converse.BedrockProxyChatModel;
import org.springframework.ai.model.function.FunctionCallingOptionsBuilder.PortableFunctionCallingOptions;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
Expand All @@ -38,12 +37,6 @@ public class BedrockConverseProxyChatProperties {
*/
private boolean enabled = true;

/**
* The generative id to use. See the {@link BedrockProxyChatModel} for the supported
* models.
*/
private String model = "anthropic.claude-3-5-sonnet-20240620-v1:0";

@NestedConfigurationProperty
private PortableFunctionCallingOptions options = PortableFunctionCallingOptions.builder()
.withTemperature(0.7)
Expand All @@ -59,14 +52,6 @@ public void setEnabled(boolean enabled) {
this.enabled = enabled;
}

public String getModel() {
return this.model;
}

public void setModel(String model) {
this.model = model;
}

public PortableFunctionCallingOptions getOptions() {
return this.options;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,12 @@ public void chatCompletionDisabled() {
.run(context -> assertThat(context.getBeansOfType(BedrockConverseProxyChatProperties.class)).isNotEmpty());

// Explicitly enable the chat auto-configuration.
new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.converse.chat..enabled=true")
new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.converse.chat.enabled=true")
.withConfiguration(AutoConfigurations.of(BedrockConverseProxyChatAutoConfiguration.class))
.run(context -> assertThat(context.getBeansOfType(BedrockConverseProxyChatProperties.class)).isNotEmpty());

// Explicitly disable the chat auto-configuration.
new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.converse.chat..enabled=false")
new ApplicationContextRunner().withPropertyValues("spring.ai.bedrock.converse.chat.enabled=false")
.withConfiguration(AutoConfigurations.of(BedrockConverseProxyChatAutoConfiguration.class))
.run(context -> assertThat(context.getBeansOfType(BedrockConverseProxyChatProperties.class)).isEmpty());
}
Expand Down
Loading