Skip to content

Commit

Permalink
🧑‍💻 better support for non-returning functions
Browse files Browse the repository at this point in the history
  • Loading branch information
Grogdunn committed May 6, 2024
1 parent 7ce90d0 commit 5948013
Show file tree
Hide file tree
Showing 9 changed files with 86 additions and 59 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
Expand All @@ -50,6 +51,7 @@
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -390,10 +392,18 @@ public ChatCompletion build() {
}

@Override
protected CompleteRoundTripBox<ChatCompletionRequest> doCreateToolResponseRequest(
ChatCompletionRequest previousRequest, RequestMessage responseMessage,
List<RequestMessage> conversationHistory) {
boolean needCompleteRoundTrip = false;
protected boolean hasReturningFunction(RequestMessage responseMessage) {
return responseMessage.content()
.stream()
.filter(c -> c.type() == MediaContent.Type.TOOL_USE)
.map(MediaContent::name)
.map(functionName -> Optional.ofNullable(this.functionCallbackRegister.get(functionName)))
.anyMatch(functionCallback -> functionCallback.map(FunctionCallback::returningFunction).orElse(false));
}

@Override
protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest,
RequestMessage responseMessage, List<RequestMessage> conversationHistory) {
List<MediaContent> toolToUseList = responseMessage.content()
.stream()
.filter(c -> c.type() == MediaContent.Type.TOOL_USE)
Expand All @@ -414,7 +424,6 @@ protected CompleteRoundTripBox<ChatCompletionRequest> doCreateToolResponseReques
String functionResponse = this.functionCallbackRegister.get(functionName)
.call(ModelOptionsUtils.toJsonString(functionArguments));
if (functionResponse != null) {
needCompleteRoundTrip = true;
toolResults.add(new MediaContent(Type.TOOL_RESULT, functionCallId, functionResponse));
}
}
Expand All @@ -425,7 +434,7 @@ protected CompleteRoundTripBox<ChatCompletionRequest> doCreateToolResponseReques
// Recursively call chatCompletionWithTools until the model doesn't call a
// functions anymore.
final var build = ChatCompletionRequest.from(previousRequest).withMessages(conversationHistory).build();
return new CompleteRoundTripBox<>(needCompleteRoundTrip, build);
return build;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
Expand All @@ -57,6 +58,7 @@
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;

/**
Expand Down Expand Up @@ -426,11 +428,18 @@ private ChatCompletionsOptions merge(ChatCompletionsOptions fromOptions, ChatCom
}

@Override
protected CompleteRoundTripBox<ChatCompletionsOptions> doCreateToolResponseRequest(
ChatCompletionsOptions previousRequest, ChatRequestMessage responseMessage,
List<ChatRequestMessage> conversationHistory) {
protected boolean hasReturningFunction(ChatRequestMessage responseMessage) {
return ((ChatRequestAssistantMessage) responseMessage).getToolCalls()
.stream()
.map(toolCall -> ((ChatCompletionsFunctionToolCall) toolCall).getFunction().getName())
.map(functionName -> Optional.ofNullable(this.functionCallbackRegister.get(functionName)))
.anyMatch(functionCallback -> functionCallback.map(FunctionCallback::returningFunction).orElse(false));
}

@Override
protected ChatCompletionsOptions doCreateToolResponseRequest(ChatCompletionsOptions previousRequest,
ChatRequestMessage responseMessage, List<ChatRequestMessage> conversationHistory) {

boolean needCompleteRoundTrip = false;
// Every tool-call item requires a separate function call and a response (TOOL)
// message.
for (ChatCompletionsToolCall toolCall : ((ChatRequestAssistantMessage) responseMessage).getToolCalls()) {
Expand All @@ -445,7 +454,6 @@ protected CompleteRoundTripBox<ChatCompletionsOptions> doCreateToolResponseReque
String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);

if (functionResponse != null) {
needCompleteRoundTrip = true;
// Add the function response to the conversation.
conversationHistory.add(new ChatRequestToolMessage(functionResponse, toolCall.getId()));
}
Expand All @@ -457,7 +465,7 @@ protected CompleteRoundTripBox<ChatCompletionsOptions> doCreateToolResponseReque

newRequest = merge(previousRequest, newRequest);

return new CompleteRoundTripBox<>(needCompleteRoundTrip, newRequest);
return newRequest;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import org.springframework.ai.mistralai.api.MistralAiApi.ChatCompletionRequest;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
Expand Down Expand Up @@ -241,14 +242,18 @@ private List<MistralAiApi.FunctionTool> getFunctionTools(Set<String> functionNam
}).toList();
}

//
// Function Calling Support
//
@Override
protected CompleteRoundTripBox<ChatCompletionRequest> doCreateToolResponseRequest(
ChatCompletionRequest previousRequest, ChatCompletionMessage responseMessage,
List<ChatCompletionMessage> conversationHistory) {
boolean needCompleteRoundTrip = false;
protected boolean hasReturningFunction(ChatCompletionMessage responseMessage) {
return responseMessage.toolCalls()
.stream()
.map(toolCall -> toolCall.function().name())
.map(functionName -> Optional.ofNullable(this.functionCallbackRegister.get(functionName)))
.anyMatch(functionCallback -> functionCallback.map(FunctionCallback::returningFunction).orElse(false));
}

@Override
protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest,
ChatCompletionMessage responseMessage, List<ChatCompletionMessage> conversationHistory) {
// Every tool-call item requires a separate function call and a response (TOOL)
// message.
for (ToolCall toolCall : responseMessage.toolCalls()) {
Expand All @@ -262,7 +267,6 @@ protected CompleteRoundTripBox<ChatCompletionRequest> doCreateToolResponseReques

String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);
if (functionResponse != null) {
needCompleteRoundTrip = true;
// Add the function response to the conversation.
conversationHistory.add(new ChatCompletionMessage(functionResponse, ChatCompletionMessage.Role.TOOL,
functionName, null));
Expand All @@ -275,7 +279,7 @@ protected CompleteRoundTripBox<ChatCompletionRequest> doCreateToolResponseReques
ChatCompletionRequest newRequest = new ChatCompletionRequest(conversationHistory, false);
newRequest = ModelOptionsUtils.merge(newRequest, previousRequest, ChatCompletionRequest.class);

return new CompleteRoundTripBox<>(needCompleteRoundTrip, newRequest);
return newRequest;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.openai.api.OpenAiApi;
import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
Expand Down Expand Up @@ -323,11 +324,18 @@ private List<OpenAiApi.FunctionTool> getFunctionTools(Set<String> functionNames)
}

@Override
protected CompleteRoundTripBox<ChatCompletionRequest> doCreateToolResponseRequest(
ChatCompletionRequest previousRequest, ChatCompletionMessage responseMessage,
List<ChatCompletionMessage> conversationHistory) {
protected boolean hasReturningFunction(ChatCompletionMessage responseMessage) {
return responseMessage.toolCalls()
.stream()
.map(toolCall -> toolCall.function().name())
.map(functionName -> Optional.ofNullable(this.functionCallbackRegister.get(functionName)))
.anyMatch(functionCallback -> functionCallback.map(FunctionCallback::returningFunction).orElse(false));
}

@Override
protected ChatCompletionRequest doCreateToolResponseRequest(ChatCompletionRequest previousRequest,
ChatCompletionMessage responseMessage, List<ChatCompletionMessage> conversationHistory) {

boolean needCompleteRoundTrip = false;
// Every tool-call item requires a separate function call and a response (TOOL)
// message.
for (ToolCall toolCall : responseMessage.toolCalls()) {
Expand All @@ -341,7 +349,6 @@ protected CompleteRoundTripBox<ChatCompletionRequest> doCreateToolResponseReques

String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);
if (functionResponse != null) {
needCompleteRoundTrip = true;
// Add the function response to the conversation.
conversationHistory
.add(new ChatCompletionMessage(functionResponse, Role.TOOL, functionName, toolCall.id(), null));
Expand All @@ -354,7 +361,7 @@ protected CompleteRoundTripBox<ChatCompletionRequest> doCreateToolResponseReques

newRequest = ModelOptionsUtils.merge(newRequest, previousRequest, ChatCompletionRequest.class);

return new CompleteRoundTripBox<>(needCompleteRoundTrip, newRequest);
return newRequest;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.model.function.AbstractFunctionCallSupport;
import org.springframework.ai.model.function.FunctionCallback;
import org.springframework.ai.model.function.FunctionCallbackContext;
import org.springframework.ai.vertexai.gemini.metadata.VertexAiChatResponseMetadata;
import org.springframework.ai.vertexai.gemini.metadata.VertexAiUsage;
Expand All @@ -57,6 +58,7 @@
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

Expand Down Expand Up @@ -400,9 +402,16 @@ public void destroy() throws Exception {
}

@Override
protected CompleteRoundTripBox<GeminiRequest> doCreateToolResponseRequest(GeminiRequest previousRequest,
Content responseMessage, List<Content> conversationHistory) {
boolean needCompleteRoundTrip = false;
protected boolean hasReturningFunction(Content responseMessage) {
final var functionName = responseMessage.getPartsList().get(0).getFunctionCall().getName();
return Optional.ofNullable(this.functionCallbackRegister.get(functionName))
.map(FunctionCallback::returningFunction)
.orElse(false);
}

@Override
protected GeminiRequest doCreateToolResponseRequest(GeminiRequest previousRequest, Content responseMessage,
List<Content> conversationHistory) {
FunctionCall functionCall = responseMessage.getPartsList().iterator().next().getFunctionCall();

var functionName = functionCall.getName();
Expand All @@ -414,7 +423,6 @@ protected CompleteRoundTripBox<GeminiRequest> doCreateToolResponseRequest(Gemini

String functionResponse = this.functionCallbackRegister.get(functionName).call(functionArguments);
if (functionResponse != null) {
needCompleteRoundTrip = true;
Content contentFnResp = Content.newBuilder()
.addParts(Part.newBuilder()
.setFunctionResponse(FunctionResponse.newBuilder()
Expand All @@ -428,7 +436,7 @@ protected CompleteRoundTripBox<GeminiRequest> doCreateToolResponseRequest(Gemini
}

final var geminiRequest = new GeminiRequest(conversationHistory, previousRequest.model());
return new CompleteRoundTripBox<>(needCompleteRoundTrip, geminiRequest);
return geminiRequest;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,17 @@ protected Resp handleFunctionCallOrReturn(Req request, Resp response) {
// Add the assistant response to the message conversation history.
conversationHistory.add(responseMessage);

CompleteRoundTripBox<Req> needRoundTripAndResponse = this.doCreateToolResponseRequest(request, responseMessage,
conversationHistory);
if (!needRoundTripAndResponse.completeRoundTrip) {
Req newRequest = this.doCreateToolResponseRequest(request, responseMessage, conversationHistory);

if (!this.hasReturningFunction(responseMessage)) {
return response;
}
return this.callWithFunctionSupport(needRoundTripAndResponse.getResponseMessage());
return this.callWithFunctionSupport(newRequest);
}

abstract protected CompleteRoundTripBox<Req> doCreateToolResponseRequest(Req previousRequest, Msg responseMessage,
abstract protected boolean hasReturningFunction(Msg responseMessage);

abstract protected Req doCreateToolResponseRequest(Req previousRequest, Msg responseMessage,
List<Msg> conversationHistory);

abstract protected List<Msg> doGetUserMessages(Req request);
Expand All @@ -158,25 +160,4 @@ abstract protected CompleteRoundTripBox<Req> doCreateToolResponseRequest(Req pre

abstract protected boolean isToolFunctionCall(Resp response);

public static class CompleteRoundTripBox<Resp> {

private final boolean completeRoundTrip;

private final Resp responseMessage;

public CompleteRoundTripBox(boolean completeRoundTrip, Resp responseMessage) {
this.completeRoundTrip = completeRoundTrip;
this.responseMessage = responseMessage;
}

public Resp getResponseMessage() {
return responseMessage;
}

public boolean isCompleteRoundTrip() {
return completeRoundTrip;
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,11 @@ public String getInputTypeSchema() {
return this.inputTypeSchema;
}

@Override
public boolean returningFunction() {
return !outputType.isAssignableFrom(Void.class);
}

@Override
public String call(String functionArguments) {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,4 +49,9 @@ public interface FunctionCallback {
*/
public String call(String functionInput);

/**
* @return This function return a value or not
*/
boolean returningFunction();

}
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ public FunctionCallback getFunctionCallback(@NonNull String beanName, @Nullable
.withInputType(functionInputClass)
.build();
}
if (bean instanceof Consumer consumer) {
if (bean instanceof Consumer<?> consumer) {
return FunctionCallbackWrapper.builder(consumer)
.withName(functionName)
.withSchemaType(this.schemaType)
Expand Down

0 comments on commit 5948013

Please sign in to comment.