Skip to content

Commit

Permalink
🩹 add google search retrieval as builtin tool
Browse files Browse the repository at this point in the history
  • Loading branch information
Grogdunn committed May 13, 2024
1 parent d7dad6e commit 70cb780
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 89 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,7 @@
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.google.cloud.vertexai.VertexAI;
import com.google.cloud.vertexai.api.Content;
import com.google.cloud.vertexai.api.FunctionCall;
import com.google.cloud.vertexai.api.FunctionDeclaration;
import com.google.cloud.vertexai.api.FunctionResponse;
import com.google.cloud.vertexai.api.GenerateContentResponse;
import com.google.cloud.vertexai.api.GenerationConfig;
import com.google.cloud.vertexai.api.Part;
import com.google.cloud.vertexai.api.Schema;
import com.google.cloud.vertexai.api.Tool;
import com.google.cloud.vertexai.api.*;
import com.google.cloud.vertexai.generativeai.GenerativeModel;
import com.google.cloud.vertexai.generativeai.PartMaker;
import com.google.cloud.vertexai.generativeai.ResponseStream;
Expand Down Expand Up @@ -243,8 +235,19 @@ private GeminiRequest createGeminiRequest(Prompt prompt) {
}

// Add the enabled functions definitions to the request's tools parameter.

List<Tool> tools = new ArrayList<>();
if (!CollectionUtils.isEmpty(functionsForThisRequest)) {
List<Tool> tools = this.getFunctionTools(functionsForThisRequest);
tools.addAll(this.getFunctionTools(functionsForThisRequest));
}
if (((VertexAiGeminiChatOptions) prompt.getOptions()).getGoogleSearchRetrieval()) {
final var googleSearchRetrieval = GoogleSearchRetrieval.newBuilder().getDefaultInstanceForType();
final var googleSearchRetrievalTool = Tool.newBuilder()
.setGoogleSearchRetrieval(googleSearchRetrieval)
.build();
tools.add(googleSearchRetrievalTool);
}
if (!CollectionUtils.isEmpty(tools)) {
generativeModelBuilder.setTools(tools);
}

Expand Down Expand Up @@ -351,6 +354,8 @@ private List<Tool> getFunctionTools(Set<String> functionNames) {

final var tool = Tool.newBuilder();

tool.setGoogleSearchRetrieval(GoogleSearchRetrieval.newBuilder().getDefaultInstanceForType());

final List<FunctionDeclaration> functionDeclarations = this.resolveFunctionCallbacks(functionNames)
.stream()
.map(functionCallback -> FunctionDeclaration.newBuilder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,7 @@
*/
package org.springframework.ai.vertexai.gemini;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.*;

import com.fasterxml.jackson.annotation.JsonIgnore;
import com.fasterxml.jackson.annotation.JsonInclude;
Expand Down Expand Up @@ -99,6 +96,12 @@ public enum TransportType {
@JsonIgnore
private Set<String> functions = new HashSet<>();

/**
* Use Google search Grounding feature
*/
@JsonIgnore
private boolean googleSearchRetrieval = false;

// @formatter:on

public static Builder builder() {
Expand Down Expand Up @@ -161,6 +164,11 @@ public Builder withFunction(String functionName) {
return this;
}

public Builder withGoogleSearchRetrieval(boolean googleSearch) {
this.options.googleSearchRetrieval = googleSearch;
return this;
}

public VertexAiGeminiChatOptions build() {
return this.options;
}
Expand Down Expand Up @@ -248,86 +256,32 @@ public void setFunctions(Set<String> functions) {
this.functions = functions;
}

@Override
public int hashCode() {
final int prime = 31;
int result = 1;
result = prime * result + ((stopSequences == null) ? 0 : stopSequences.hashCode());
result = prime * result + ((temperature == null) ? 0 : temperature.hashCode());
result = prime * result + ((topP == null) ? 0 : topP.hashCode());
result = prime * result + ((topK == null) ? 0 : topK.hashCode());
result = prime * result + ((candidateCount == null) ? 0 : candidateCount.hashCode());
result = prime * result + ((maxOutputTokens == null) ? 0 : maxOutputTokens.hashCode());
result = prime * result + ((model == null) ? 0 : model.hashCode());
result = prime * result + ((functionCallbacks == null) ? 0 : functionCallbacks.hashCode());
result = prime * result + ((functions == null) ? 0 : functions.hashCode());
return result;
public boolean getGoogleSearchRetrieval() {
return this.googleSearchRetrieval;
}

public void setGoogleSearchRetrieval(boolean googleSearchRetrieval) {
this.googleSearchRetrieval = googleSearchRetrieval;
}

@Override
public boolean equals(Object obj) {
if (this == obj)
public boolean equals(Object o) {
if (this == o)
return true;
if (obj == null)
return false;
if (getClass() != obj.getClass())
return false;
VertexAiGeminiChatOptions other = (VertexAiGeminiChatOptions) obj;
if (stopSequences == null) {
if (other.stopSequences != null)
return false;
}
else if (!stopSequences.equals(other.stopSequences))
return false;
if (temperature == null) {
if (other.temperature != null)
return false;
}
else if (!temperature.equals(other.temperature))
return false;
if (topP == null) {
if (other.topP != null)
return false;
}
else if (!topP.equals(other.topP))
return false;
if (topK == null) {
if (other.topK != null)
return false;
}
else if (!topK.equals(other.topK))
return false;
if (candidateCount == null) {
if (other.candidateCount != null)
return false;
}
else if (!candidateCount.equals(other.candidateCount))
return false;
if (maxOutputTokens == null) {
if (other.maxOutputTokens != null)
return false;
}
else if (!maxOutputTokens.equals(other.maxOutputTokens))
return false;
if (model == null) {
if (other.model != null)
return false;
}
else if (!model.equals(other.model))
if (!(o instanceof VertexAiGeminiChatOptions that))
return false;
if (functionCallbacks == null) {
if (other.functionCallbacks != null)
return false;
}
else if (!functionCallbacks.equals(other.functionCallbacks))
return false;
if (functions == null) {
if (other.functions != null)
return false;
}
else if (!functions.equals(other.functions))
return false;
return true;
return googleSearchRetrieval == that.googleSearchRetrieval && Objects.equals(stopSequences, that.stopSequences)
&& Objects.equals(temperature, that.temperature) && Objects.equals(topP, that.topP)
&& Objects.equals(topK, that.topK) && Objects.equals(candidateCount, that.candidateCount)
&& Objects.equals(maxOutputTokens, that.maxOutputTokens) && Objects.equals(model, that.model)
&& Objects.equals(functionCallbacks, that.functionCallbacks)
&& Objects.equals(functions, that.functions);
}

@Override
public int hashCode() {
return Objects.hash(stopSequences, temperature, topP, topK, candidateCount, maxOutputTokens, model,
functionCallbacks, functions, googleSearchRetrieval);
}

}

0 comments on commit 70cb780

Please sign in to comment.