Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,11 @@ public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions,

public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager,
ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) {
super(null, defaultOptions, List.of());
// We do not pass the 'defaultOptions' to the AbstractToolSupport, because it
// modifies them.
// We are not using the AbstractToolSupport class in this path, so we just pass
// empty options.
super(null, OllamaOptions.builder().build(), List.of());
Assert.notNull(ollamaApi, "ollamaApi must not be null");
Assert.notNull(defaultOptions, "defaultOptions must not be null");
Assert.notNull(toolCallingManager, "toolCallingManager must not be null");
Expand Down Expand Up @@ -395,17 +399,24 @@ else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOp
// Define request options by merging runtime options and default options
OllamaOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions,
OllamaOptions.class);
// Merge tool names and tool callbacks explicitly since they are ignored by
// Merge @JsonIgnore-annotated options explicitly since they are ignored by
// Jackson, used by ModelOptionsUtils.
if (runtimeOptions != null) {
requestOptions.setInternalToolExecutionEnabled(
ModelOptionsUtils.mergeOption(runtimeOptions.isInternalToolExecutionEnabled(),
this.defaultOptions.isInternalToolExecutionEnabled()));
requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(),
this.defaultOptions.getToolNames()));
requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(),
this.defaultOptions.getToolCallbacks()));
requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(),
this.defaultOptions.getToolContext()));
}
else {
requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.isInternalToolExecutionEnabled());
requestOptions.setToolNames(this.defaultOptions.getToolNames());
requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks());
requestOptions.setToolContext(this.defaultOptions.getToolContext());
}

// Validate request options
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -331,7 +332,7 @@ public class OllamaOptions implements ToolCallingChatOptions, EmbeddingOptions {
private Set<String> toolNames = new HashSet<>();

@JsonIgnore
private Map<String, Object> toolContext;
private Map<String, Object> toolContext = new HashMap<>();

public static Builder builder() {
return new Builder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,13 @@

import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.model.tool.ToolCallingChatOptions;
import org.springframework.ai.ollama.api.OllamaApi;
import org.springframework.ai.ollama.api.OllamaOptions;
import org.springframework.ai.tool.ToolCallback;
import org.springframework.ai.tool.definition.ToolDefinition;

import java.util.Map;

import static org.assertj.core.api.Assertions.assertThat;

Expand All @@ -36,6 +41,37 @@ class OllamaChatRequestTests {
.defaultOptions(OllamaOptions.builder().model("MODEL_NAME").topK(99).temperature(66.6).numGPU(1).build())
.build();

@Test
void whenToolRuntimeOptionsThenMergeWithDefaults() {
OllamaOptions defaultOptions = OllamaOptions.builder()
.model("MODEL_NAME")
.internalToolExecutionEnabled(true)
.toolCallbacks(new TestToolCallback("tool1"), new TestToolCallback("tool2"))
.toolNames("tool1", "tool2")
.toolContext(Map.of("key1", "value1"))
.build();
OllamaChatModel chatModel = OllamaChatModel.builder()
.ollamaApi(new OllamaApi())
.defaultOptions(defaultOptions)
.build();

OllamaOptions runtimeOptions = OllamaOptions.builder()
.internalToolExecutionEnabled(false)
.toolCallbacks(new TestToolCallback("tool3"), new TestToolCallback("tool4"))
.toolNames("tool3")
.toolContext(Map.of("key2", "value2"))
.build();
Prompt prompt = chatModel.buildRequestPrompt(new Prompt("Test message content", runtimeOptions));

assertThat(((ToolCallingChatOptions) prompt.getOptions())).isNotNull();
assertThat(((ToolCallingChatOptions) prompt.getOptions()).isInternalToolExecutionEnabled()).isFalse();
assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks()).hasSize(4);
assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolNames()).containsExactlyInAnyOrder("tool1",
"tool2", "tool3");
assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolContext()).containsEntry("key1", "value1")
.containsEntry("key2", "value2");
}

@Test
void createRequestWithDefaultOptions() {
var prompt = this.chatModel.buildRequestPrompt(new Prompt("Test message content"));
Expand Down Expand Up @@ -124,4 +160,24 @@ public void createRequestWithDefaultOptionsModelOverride() {
assertThat(request.model()).isEqualTo("PROMPT_MODEL");
}

static class TestToolCallback implements ToolCallback {

private final ToolDefinition toolDefinition;

public TestToolCallback(String name) {
this.toolDefinition = ToolDefinition.builder().name(name).inputSchema("{}").build();
}

@Override
public ToolDefinition getToolDefinition() {
return toolDefinition;
}

@Override
public String call(String toolInput) {
return "Mission accomplished!";
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.springframework.util.Assert;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -204,4 +205,15 @@ static List<FunctionCallback> mergeToolCallbacks(List<FunctionCallback> runtimeT
return mergedToolCallbacks;
}

static Map<String, Object> mergeToolContext(Map<String, Object> runtimeToolContext,
Map<String, Object> defaultToolContext) {
Assert.notNull(runtimeToolContext, "runtimeToolContext cannot be null");
Assert.noNullElements(runtimeToolContext.keySet(), "runtimeToolContext keys cannot be null");
Assert.notNull(defaultToolContext, "defaultToolContext cannot be null");
Assert.noNullElements(defaultToolContext.keySet(), "defaultToolContext keys cannot be null");
var mergedToolContext = new HashMap<>(defaultToolContext);
mergedToolContext.putAll(runtimeToolContext);
return mergedToolContext;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.springframework.ai.tool.definition.ToolDefinition;

import java.util.List;
import java.util.Map;
import java.util.Set;

import static org.assertj.core.api.Assertions.assertThat;
Expand Down Expand Up @@ -141,6 +142,47 @@ void whenMergeEmptyRuntimeAndEmptyDefaultToolCallbacks() {
assertThat(mergedToolCallbacks).hasSize(0);
}

@Test
void whenMergeRuntimeAndDefaultToolContext() {
Map<String, Object> runtimeToolContext = Map.of("key1", "value1", "key2", "value2");
Map<String, Object> defaultToolContext = Map.of("key1", "valueA", "key3", "value3");
Map<String, Object> mergedToolContext = ToolCallingChatOptions.mergeToolContext(runtimeToolContext,
defaultToolContext);
assertThat(mergedToolContext).hasSize(3);
assertThat(mergedToolContext).containsEntry("key1", "value1")
.containsEntry("key2", "value2")
.containsEntry("key3", "value3");
}

@Test
void whenMergeRuntimeAndEmptyDefaultToolContext() {
Map<String, Object> runtimeToolContext = Map.of("key1", "value1", "key2", "value2");
Map<String, Object> defaultToolContext = Map.of();
Map<String, Object> mergedToolContext = ToolCallingChatOptions.mergeToolContext(runtimeToolContext,
defaultToolContext);
assertThat(mergedToolContext).hasSize(2);
assertThat(mergedToolContext).containsEntry("key1", "value1").containsEntry("key2", "value2");
}

@Test
void whenMergeEmptyRuntimeAndDefaultToolContext() {
Map<String, Object> runtimeToolContext = Map.of();
Map<String, Object> defaultToolContext = Map.of("key1", "value1", "key2", "value2");
Map<String, Object> mergedToolContext = ToolCallingChatOptions.mergeToolContext(runtimeToolContext,
defaultToolContext);
assertThat(mergedToolContext).hasSize(2);
assertThat(mergedToolContext).containsEntry("key1", "value1").containsEntry("key2", "value2");
}

@Test
void whenMergeEmptyRuntimeAndEmptyDefaultToolContext() {
Map<String, Object> runtimeToolContext = Map.of();
Map<String, Object> defaultToolContext = Map.of();
Map<String, Object> mergedToolContext = ToolCallingChatOptions.mergeToolContext(runtimeToolContext,
defaultToolContext);
assertThat(mergedToolContext).hasSize(0);
}

static class TestToolCallback implements ToolCallback {

private final ToolDefinition toolDefinition;
Expand Down