From 5948013c29b7d51eb531183be74c0cd15ea12f41 Mon Sep 17 00:00:00 2001 From: Lorenzo Caenazzo Date: Mon, 6 May 2024 09:52:33 +0200 Subject: [PATCH] :technologist: better support for non-returning functions --- .../ai/anthropic/AnthropicChatClient.java | 21 ++++++++---- .../azure/openai/AzureOpenAiChatClient.java | 20 +++++++---- .../ai/mistralai/MistralAiChatClient.java | 22 ++++++++----- .../ai/openai/OpenAiChatClient.java | 19 +++++++---- .../gemini/VertexAiGeminiChatClient.java | 18 +++++++--- .../function/AbstractFunctionCallSupport.java | 33 ++++--------------- .../function/AbstractFunctionCallback.java | 5 +++ .../ai/model/function/FunctionCallback.java | 5 +++ .../function/FunctionCallbackContext.java | 2 +- 9 files changed, 86 insertions(+), 59 deletions(-) diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatClient.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatClient.java index 8c0a8a0349..c8b3a64acd 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatClient.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatClient.java @@ -37,6 +37,7 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.AbstractFunctionCallSupport; +import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.ai.retry.RetryUtils; import org.springframework.http.ResponseEntity; @@ -50,6 +51,7 @@ import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; @@ -390,10 +392,18 @@ public ChatCompletion build() { } @Override - protected CompleteRoundTripBox doCreateToolResponseRequest( - ChatCompletionRequest previousRequest, RequestMessage responseMessage, - List conversationHistory) { - boolean needCompleteRoundTrip = false; + protected boolean hasReturningFunction(RequestMessage responseMessage) { + return responseMessage.content() + .stream() + .filter(c -> c.type() == MediaContent.Type.TOOL_USE) + .map(MediaContent::name) + .map(functionName -> Optional.ofNullable(this.functionCallbackRegister.get(functionName))) + .anyMatch(functionCallback -> functionCallback.map(FunctionCallback::returningFunction).orElse(false)); + } + + @Override + protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest, + RequestMessage responseMessage, List conversationHistory) { List toolToUseList = responseMessage.content() .stream() .filter(c -> c.type() == MediaContent.Type.TOOL_USE) @@ -414,7 +424,6 @@ protected CompleteRoundTripBox doCreateToolResponseReques String functionResponse = this.functionCallbackRegister.get(functionName) .call(ModelOptionsUtils.toJsonString(functionArguments)); if (functionResponse != null) { - needCompleteRoundTrip = true; toolResults.add(new MediaContent(Type.TOOL_RESULT, functionCallId, functionResponse)); } } @@ -425,7 +434,7 @@ protected CompleteRoundTripBox doCreateToolResponseReques // Recursively call chatCompletionWithTools until the model doesn't call a // functions anymore. final var build = ChatCompletionRequest.from(previousRequest).withMessages(conversationHistory).build(); - return new CompleteRoundTripBox<>(needCompleteRoundTrip, build); + return build; } @Override diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java index ce3070d97f..e51c5c8f57 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatClient.java @@ -49,6 +49,7 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.AbstractFunctionCallSupport; +import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -57,6 +58,7 @@ import java.util.Collections; import java.util.HashSet; import java.util.List; +import java.util.Optional; import java.util.Set; /** @@ -426,11 +428,18 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromOptions, ChatCom } @Override - protected CompleteRoundTripBox doCreateToolResponseRequest( - ChatCompletionsOptions previousRequest, ChatRequestMessage responseMessage, - List conversationHistory) { + protected boolean hasReturningFunction(ChatRequestMessage responseMessage) { + return ((ChatRequestAssistantMessage) responseMessage).getToolCalls() + .stream() + .map(toolCall -> ((ChatCompletionsFunctionToolCall) toolCall).getFunction().getName()) + .map(functionName -> Optional.ofNullable(this.functionCallbackRegister.get(functionName))) + .anyMatch(functionCallback -> functionCallback.map(FunctionCallback::returningFunction).orElse(false)); + } + + @Override + protected ChatCompletionsOptions doCreateToolResponseRequest(ChatCompletionsOptions previousRequest, + ChatRequestMessage responseMessage, List conversationHistory) { - boolean needCompleteRoundTrip = false; // Every tool-call item requires a separate function call and a response (TOOL) // message. for (ChatCompletionsToolCall toolCall : ((ChatRequestAssistantMessage) responseMessage).getToolCalls()) { @@ -445,7 +454,6 @@ protected CompleteRoundTripBox doCreateToolResponseReque String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments); if (functionResponse != null) { - needCompleteRoundTrip = true; // Add the function response to the conversation. conversationHistory.add(new ChatRequestToolMessage(functionResponse, toolCall.getId())); } @@ -457,7 +465,7 @@ protected CompleteRoundTripBox doCreateToolResponseReque newRequest = merge(previousRequest, newRequest); - return new CompleteRoundTripBox<>(needCompleteRoundTrip, newRequest); + return newRequest; } @Override diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java index 46965acbdd..d2008281b3 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatClient.java @@ -33,6 +33,7 @@ import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.AbstractFunctionCallSupport; +import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.ai.retry.RetryUtils; import org.springframework.http.ResponseEntity; @@ -241,14 +242,18 @@ private List getFunctionTools(Set functionNam }).toList(); } - // - // Function Calling Support - // @Override - protected CompleteRoundTripBox doCreateToolResponseRequest( - ChatCompletionRequest previousRequest, ChatCompletionMessage responseMessage, - List conversationHistory) { - boolean needCompleteRoundTrip = false; + protected boolean hasReturningFunction(ChatCompletionMessage responseMessage) { + return responseMessage.toolCalls() + .stream() + .map(toolCall -> toolCall.function().name()) + .map(functionName -> Optional.ofNullable(this.functionCallbackRegister.get(functionName))) + .anyMatch(functionCallback -> functionCallback.map(FunctionCallback::returningFunction).orElse(false)); + } + + @Override + protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest, + ChatCompletionMessage responseMessage, List conversationHistory) { // Every tool-call item requires a separate function call and a response (TOOL) // message. for (ToolCall toolCall : responseMessage.toolCalls()) { @@ -262,7 +267,6 @@ protected CompleteRoundTripBox doCreateToolResponseReques String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments); if (functionResponse != null) { - needCompleteRoundTrip = true; // Add the function response to the conversation. conversationHistory.add(new ChatCompletionMessage(functionResponse, ChatCompletionMessage.Role.TOOL, functionName, null)); @@ -275,7 +279,7 @@ protected CompleteRoundTripBox doCreateToolResponseReques ChatCompletionRequest newRequest = new ChatCompletionRequest(conversationHistory, false); newRequest = ModelOptionsUtils.merge(newRequest, previousRequest, ChatCompletionRequest.class); - return new CompleteRoundTripBox<>(needCompleteRoundTrip, newRequest); + return newRequest; } @Override diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java index b82e59d33f..65bb042307 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java @@ -27,6 +27,7 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.AbstractFunctionCallSupport; +import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; @@ -323,11 +324,18 @@ private List getFunctionTools(Set functionNames) } @Override - protected CompleteRoundTripBox doCreateToolResponseRequest( - ChatCompletionRequest previousRequest, ChatCompletionMessage responseMessage, - List conversationHistory) { + protected boolean hasReturningFunction(ChatCompletionMessage responseMessage) { + return responseMessage.toolCalls() + .stream() + .map(toolCall -> toolCall.function().name()) + .map(functionName -> Optional.ofNullable(this.functionCallbackRegister.get(functionName))) + .anyMatch(functionCallback -> functionCallback.map(FunctionCallback::returningFunction).orElse(false)); + } + + @Override + protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest, + ChatCompletionMessage responseMessage, List conversationHistory) { - boolean needCompleteRoundTrip = false; // Every tool-call item requires a separate function call and a response (TOOL) // message. for (ToolCall toolCall : responseMessage.toolCalls()) { @@ -341,7 +349,6 @@ protected CompleteRoundTripBox doCreateToolResponseReques String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments); if (functionResponse != null) { - needCompleteRoundTrip = true; // Add the function response to the conversation. conversationHistory .add(new ChatCompletionMessage(functionResponse, Role.TOOL, functionName, toolCall.id(), null)); @@ -354,7 +361,7 @@ protected CompleteRoundTripBox doCreateToolResponseReques newRequest = ModelOptionsUtils.merge(newRequest, previousRequest, ChatCompletionRequest.class); - return new CompleteRoundTripBox<>(needCompleteRoundTrip, newRequest); + return newRequest; } @Override diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatClient.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatClient.java index 016f58ae97..395511616a 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatClient.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatClient.java @@ -44,6 +44,7 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.function.AbstractFunctionCallSupport; +import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackContext; import org.springframework.ai.vertexai.gemini.metadata.VertexAiChatResponseMetadata; import org.springframework.ai.vertexai.gemini.metadata.VertexAiUsage; @@ -57,6 +58,7 @@ import java.util.ArrayList; import java.util.HashSet; import java.util.List; +import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -400,9 +402,16 @@ public void destroy() throws Exception { } @Override - protected CompleteRoundTripBox doCreateToolResponseRequest(GeminiRequest previousRequest, - Content responseMessage, List conversationHistory) { - boolean needCompleteRoundTrip = false; + protected boolean hasReturningFunction(Content responseMessage) { + final var functionName = responseMessage.getPartsList().get(0).getFunctionCall().getName(); + return Optional.ofNullable(this.functionCallbackRegister.get(functionName)) + .map(FunctionCallback::returningFunction) + .orElse(false); + } + + @Override + protected GeminiRequest doCreateToolResponseRequest(GeminiRequest previousRequest, Content responseMessage, + List conversationHistory) { FunctionCall functionCall = responseMessage.getPartsList().iterator().next().getFunctionCall(); var functionName = functionCall.getName(); @@ -414,7 +423,6 @@ protected CompleteRoundTripBox doCreateToolResponseRequest(Gemini String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments); if (functionResponse != null) { - needCompleteRoundTrip = true; Content contentFnResp = Content.newBuilder() .addParts(Part.newBuilder() .setFunctionResponse(FunctionResponse.newBuilder() @@ -428,7 +436,7 @@ protected CompleteRoundTripBox doCreateToolResponseRequest(Gemini } final var geminiRequest = new GeminiRequest(conversationHistory, previousRequest.model()); - return new CompleteRoundTripBox<>(needCompleteRoundTrip, geminiRequest); + return geminiRequest; } @Override diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallSupport.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallSupport.java index 4036a5a9a4..b9209e135d 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallSupport.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallSupport.java @@ -139,15 +139,17 @@ protected Resp handleFunctionCallOrReturn(Req request, Resp response) { // Add the assistant response to the message conversation history. conversationHistory.add(responseMessage); - CompleteRoundTripBox needRoundTripAndResponse = this.doCreateToolResponseRequest(request, responseMessage, - conversationHistory); - if (!needRoundTripAndResponse.completeRoundTrip) { + Req newRequest = this.doCreateToolResponseRequest(request, responseMessage, conversationHistory); + + if (!this.hasReturningFunction(responseMessage)) { return response; } - return this.callWithFunctionSupport(needRoundTripAndResponse.getResponseMessage()); + return this.callWithFunctionSupport(newRequest); } - abstract protected CompleteRoundTripBox doCreateToolResponseRequest(Req previousRequest, Msg responseMessage, + abstract protected boolean hasReturningFunction(Msg responseMessage); + + abstract protected Req doCreateToolResponseRequest(Req previousRequest, Msg responseMessage, List conversationHistory); abstract protected List doGetUserMessages(Req request); @@ -158,25 +160,4 @@ abstract protected CompleteRoundTripBox doCreateToolResponseRequest(Req pre abstract protected boolean isToolFunctionCall(Resp response); - public static class CompleteRoundTripBox { - - private final boolean completeRoundTrip; - - private final Resp responseMessage; - - public CompleteRoundTripBox(boolean completeRoundTrip, Resp responseMessage) { - this.completeRoundTrip = completeRoundTrip; - this.responseMessage = responseMessage; - } - - public Resp getResponseMessage() { - return responseMessage; - } - - public boolean isCompleteRoundTrip() { - return completeRoundTrip; - } - - } - } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java index 2a4c2a9d55..e5fb5f09a4 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/AbstractFunctionCallback.java @@ -100,6 +100,11 @@ public String getInputTypeSchema() { return this.inputTypeSchema; } + @Override + public boolean returningFunction() { + return !outputType.isAssignableFrom(Void.class); + } + @Override public String call(String functionArguments) { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java index 957b4e2a63..c30f0adcc8 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallback.java @@ -49,4 +49,9 @@ public interface FunctionCallback { */ public String call(String functionInput); + /** + * @return This function return a value or not + */ + boolean returningFunction(); + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java index 4c9cc535b0..5f8f883e34 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/function/FunctionCallbackContext.java @@ -119,7 +119,7 @@ public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable .withInputType(functionInputClass) .build(); } - if (bean instanceof Consumer consumer) { + if (bean instanceof Consumer consumer) { return FunctionCallbackWrapper.builder(consumer) .withName(functionName) .withSchemaType(this.schemaType)