From 70cb780f011603fd4471616dad6d2e1a3e0efcfa Mon Sep 17 00:00:00 2001 From: Lorenzo Caenazzo Date: Thu, 2 May 2024 15:08:28 +0200 Subject: [PATCH] :adhesive_bandage: add google search retrieval as builtin tool --- .../gemini/VertexAiGeminiChatClient.java | 25 ++-- .../gemini/VertexAiGeminiChatOptions.java | 112 ++++++------------ 2 files changed, 48 insertions(+), 89 deletions(-) diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatClient.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatClient.java index ad74073fa7..de2dd504d3 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatClient.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatClient.java @@ -24,15 +24,7 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.google.cloud.vertexai.VertexAI; -import com.google.cloud.vertexai.api.Content; -import com.google.cloud.vertexai.api.FunctionCall; -import com.google.cloud.vertexai.api.FunctionDeclaration; -import com.google.cloud.vertexai.api.FunctionResponse; -import com.google.cloud.vertexai.api.GenerateContentResponse; -import com.google.cloud.vertexai.api.GenerationConfig; -import com.google.cloud.vertexai.api.Part; -import com.google.cloud.vertexai.api.Schema; -import com.google.cloud.vertexai.api.Tool; +import com.google.cloud.vertexai.api.*; import com.google.cloud.vertexai.generativeai.GenerativeModel; import com.google.cloud.vertexai.generativeai.PartMaker; import com.google.cloud.vertexai.generativeai.ResponseStream; @@ -243,8 +235,19 @@ private GeminiRequest createGeminiRequest(Prompt prompt) { } // Add the enabled functions definitions to the request's tools parameter. + + List tools = new ArrayList<>(); if (!CollectionUtils.isEmpty(functionsForThisRequest)) { - List tools = this.getFunctionTools(functionsForThisRequest); + tools.addAll(this.getFunctionTools(functionsForThisRequest)); + } + if (((VertexAiGeminiChatOptions) prompt.getOptions()).getGoogleSearchRetrieval()) { + final var googleSearchRetrieval = GoogleSearchRetrieval.newBuilder().getDefaultInstanceForType(); + final var googleSearchRetrievalTool = Tool.newBuilder() + .setGoogleSearchRetrieval(googleSearchRetrieval) + .build(); + tools.add(googleSearchRetrievalTool); + } + if (!CollectionUtils.isEmpty(tools)) { generativeModelBuilder.setTools(tools); } @@ -351,6 +354,8 @@ private List getFunctionTools(Set functionNames) { final var tool = Tool.newBuilder(); + tool.setGoogleSearchRetrieval(GoogleSearchRetrieval.newBuilder().getDefaultInstanceForType()); + final List functionDeclarations = this.resolveFunctionCallbacks(functionNames) .stream() .map(functionCallback -> FunctionDeclaration.newBuilder() diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java index 7d4e9875a5..44ef0cfe6f 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatOptions.java @@ -15,10 +15,7 @@ */ package org.springframework.ai.vertexai.gemini; -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Set; +import java.util.*; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonInclude; @@ -99,6 +96,12 @@ public enum TransportType { @JsonIgnore private Set functions = new HashSet<>(); + /** + * Use Google search Grounding feature + */ + @JsonIgnore + private boolean googleSearchRetrieval = false; + // @formatter:on public static Builder builder() { @@ -161,6 +164,11 @@ public Builder withFunction(String functionName) { return this; } + public Builder withGoogleSearchRetrieval(boolean googleSearch) { + this.options.googleSearchRetrieval = googleSearch; + return this; + } + public VertexAiGeminiChatOptions build() { return this.options; } @@ -248,86 +256,32 @@ public void setFunctions(Set functions) { this.functions = functions; } - @Override - public int hashCode() { - final int prime = 31; - int result = 1; - result = prime * result + ((stopSequences == null) ? 0 : stopSequences.hashCode()); - result = prime * result + ((temperature == null) ? 0 : temperature.hashCode()); - result = prime * result + ((topP == null) ? 0 : topP.hashCode()); - result = prime * result + ((topK == null) ? 0 : topK.hashCode()); - result = prime * result + ((candidateCount == null) ? 0 : candidateCount.hashCode()); - result = prime * result + ((maxOutputTokens == null) ? 0 : maxOutputTokens.hashCode()); - result = prime * result + ((model == null) ? 0 : model.hashCode()); - result = prime * result + ((functionCallbacks == null) ? 0 : functionCallbacks.hashCode()); - result = prime * result + ((functions == null) ? 0 : functions.hashCode()); - return result; + public boolean getGoogleSearchRetrieval() { + return this.googleSearchRetrieval; + } + + public void setGoogleSearchRetrieval(boolean googleSearchRetrieval) { + this.googleSearchRetrieval = googleSearchRetrieval; } @Override - public boolean equals(Object obj) { - if (this == obj) + public boolean equals(Object o) { + if (this == o) return true; - if (obj == null) - return false; - if (getClass() != obj.getClass()) - return false; - VertexAiGeminiChatOptions other = (VertexAiGeminiChatOptions) obj; - if (stopSequences == null) { - if (other.stopSequences != null) - return false; - } - else if (!stopSequences.equals(other.stopSequences)) - return false; - if (temperature == null) { - if (other.temperature != null) - return false; - } - else if (!temperature.equals(other.temperature)) - return false; - if (topP == null) { - if (other.topP != null) - return false; - } - else if (!topP.equals(other.topP)) - return false; - if (topK == null) { - if (other.topK != null) - return false; - } - else if (!topK.equals(other.topK)) - return false; - if (candidateCount == null) { - if (other.candidateCount != null) - return false; - } - else if (!candidateCount.equals(other.candidateCount)) - return false; - if (maxOutputTokens == null) { - if (other.maxOutputTokens != null) - return false; - } - else if (!maxOutputTokens.equals(other.maxOutputTokens)) - return false; - if (model == null) { - if (other.model != null) - return false; - } - else if (!model.equals(other.model)) + if (!(o instanceof VertexAiGeminiChatOptions that)) return false; - if (functionCallbacks == null) { - if (other.functionCallbacks != null) - return false; - } - else if (!functionCallbacks.equals(other.functionCallbacks)) - return false; - if (functions == null) { - if (other.functions != null) - return false; - } - else if (!functions.equals(other.functions)) - return false; - return true; + return googleSearchRetrieval == that.googleSearchRetrieval && Objects.equals(stopSequences, that.stopSequences) + && Objects.equals(temperature, that.temperature) && Objects.equals(topP, that.topP) + && Objects.equals(topK, that.topK) && Objects.equals(candidateCount, that.candidateCount) + && Objects.equals(maxOutputTokens, that.maxOutputTokens) && Objects.equals(model, that.model) + && Objects.equals(functionCallbacks, that.functionCallbacks) + && Objects.equals(functions, that.functions); + } + + @Override + public int hashCode() { + return Objects.hash(stopSequences, temperature, topP, topK, candidateCount, maxOutputTokens, model, + functionCallbacks, functions, googleSearchRetrieval); } }