From 7f7b8baa4240be25d4fb2b405af073bc5b4d83c9 Mon Sep 17 00:00:00 2001 From: Thomas Vitale Date: Thu, 6 Feb 2025 07:25:03 +0100 Subject: [PATCH] Add merge for missing Ollama options When fields in OllamaOptions are marked as ignored in Jackson, they require explicit merge of runtime and default options. Added tests to validate the different merge combinations for all tool-related options. Signed-off-by: Thomas Vitale --- .../ai/ollama/OllamaChatModel.java | 15 ++++- .../ai/ollama/api/OllamaOptions.java | 3 +- .../ai/ollama/OllamaChatRequestTests.java | 56 +++++++++++++++++++ .../ai/model/tool/ToolCallingChatOptions.java | 12 ++++ .../tool/ToolCallingChatOptionsTests.java | 42 ++++++++++++++ 5 files changed, 125 insertions(+), 3 deletions(-) diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index b5743d80bfd..6dcb48d28f2 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -147,7 +147,11 @@ public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, public OllamaChatModel(OllamaApi ollamaApi, OllamaOptions defaultOptions, ToolCallingManager toolCallingManager, ObservationRegistry observationRegistry, ModelManagementOptions modelManagementOptions) { - super(null, defaultOptions, List.of()); + // We do not pass the 'defaultOptions' to the AbstractToolSupport, because it + // modifies them. + // We are not using the AbstractToolSupport class in this path, so we just pass + // empty options. + super(null, OllamaOptions.builder().build(), List.of()); Assert.notNull(ollamaApi, "ollamaApi must not be null"); Assert.notNull(defaultOptions, "defaultOptions must not be null"); Assert.notNull(toolCallingManager, "toolCallingManager must not be null"); @@ -395,17 +399,24 @@ else if (prompt.getOptions() instanceof FunctionCallingOptions functionCallingOp // Define request options by merging runtime options and default options OllamaOptions requestOptions = ModelOptionsUtils.merge(runtimeOptions, this.defaultOptions, OllamaOptions.class); - // Merge tool names and tool callbacks explicitly since they are ignored by + // Merge @JsonIgnore-annotated options explicitly since they are ignored by // Jackson, used by ModelOptionsUtils. if (runtimeOptions != null) { + requestOptions.setInternalToolExecutionEnabled( + ModelOptionsUtils.mergeOption(runtimeOptions.isInternalToolExecutionEnabled(), + this.defaultOptions.isInternalToolExecutionEnabled())); requestOptions.setToolNames(ToolCallingChatOptions.mergeToolNames(runtimeOptions.getToolNames(), this.defaultOptions.getToolNames())); requestOptions.setToolCallbacks(ToolCallingChatOptions.mergeToolCallbacks(runtimeOptions.getToolCallbacks(), this.defaultOptions.getToolCallbacks())); + requestOptions.setToolContext(ToolCallingChatOptions.mergeToolContext(runtimeOptions.getToolContext(), + this.defaultOptions.getToolContext())); } else { + requestOptions.setInternalToolExecutionEnabled(this.defaultOptions.isInternalToolExecutionEnabled()); requestOptions.setToolNames(this.defaultOptions.getToolNames()); requestOptions.setToolCallbacks(this.defaultOptions.getToolCallbacks()); + requestOptions.setToolContext(this.defaultOptions.getToolContext()); } // Validate request options diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java index fc815202b3f..c6706e337d0 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/api/OllamaOptions.java @@ -18,6 +18,7 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -331,7 +332,7 @@ public class OllamaOptions implements ToolCallingChatOptions, EmbeddingOptions { private Set toolNames = new HashSet<>(); @JsonIgnore - private Map toolContext; + private Map toolContext = new HashMap<>(); public static Builder builder() { return new Builder(); diff --git a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java index 9f38c6fa06a..740c299ab8a 100644 --- a/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java +++ b/models/spring-ai-ollama/src/test/java/org/springframework/ai/ollama/OllamaChatRequestTests.java @@ -20,8 +20,13 @@ import org.springframework.ai.chat.prompt.ChatOptions; import org.springframework.ai.chat.prompt.Prompt; +import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.ollama.api.OllamaApi; import org.springframework.ai.ollama.api.OllamaOptions; +import org.springframework.ai.tool.ToolCallback; +import org.springframework.ai.tool.definition.ToolDefinition; + +import java.util.Map; import static org.assertj.core.api.Assertions.assertThat; @@ -36,6 +41,37 @@ class OllamaChatRequestTests { .defaultOptions(OllamaOptions.builder().model("MODEL_NAME").topK(99).temperature(66.6).numGPU(1).build()) .build(); + @Test + void whenToolRuntimeOptionsThenMergeWithDefaults() { + OllamaOptions defaultOptions = OllamaOptions.builder() + .model("MODEL_NAME") + .internalToolExecutionEnabled(true) + .toolCallbacks(new TestToolCallback("tool1"), new TestToolCallback("tool2")) + .toolNames("tool1", "tool2") + .toolContext(Map.of("key1", "value1")) + .build(); + OllamaChatModel chatModel = OllamaChatModel.builder() + .ollamaApi(new OllamaApi()) + .defaultOptions(defaultOptions) + .build(); + + OllamaOptions runtimeOptions = OllamaOptions.builder() + .internalToolExecutionEnabled(false) + .toolCallbacks(new TestToolCallback("tool3"), new TestToolCallback("tool4")) + .toolNames("tool3") + .toolContext(Map.of("key2", "value2")) + .build(); + Prompt prompt = chatModel.buildRequestPrompt(new Prompt("Test message content", runtimeOptions)); + + assertThat(((ToolCallingChatOptions) prompt.getOptions())).isNotNull(); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).isInternalToolExecutionEnabled()).isFalse(); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolCallbacks()).hasSize(4); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolNames()).containsExactlyInAnyOrder("tool1", + "tool2", "tool3"); + assertThat(((ToolCallingChatOptions) prompt.getOptions()).getToolContext()).containsEntry("key1", "value1") + .containsEntry("key2", "value2"); + } + @Test void createRequestWithDefaultOptions() { var prompt = this.chatModel.buildRequestPrompt(new Prompt("Test message content")); @@ -124,4 +160,24 @@ public void createRequestWithDefaultOptionsModelOverride() { assertThat(request.model()).isEqualTo("PROMPT_MODEL"); } + static class TestToolCallback implements ToolCallback { + + private final ToolDefinition toolDefinition; + + public TestToolCallback(String name) { + this.toolDefinition = ToolDefinition.builder().name(name).inputSchema("{}").build(); + } + + @Override + public ToolDefinition getToolDefinition() { + return toolDefinition; + } + + @Override + public String call(String toolInput) { + return "Mission accomplished!"; + } + + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java b/spring-ai-core/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java index 6be2cf37d37..b4c25f91172 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/model/tool/ToolCallingChatOptions.java @@ -24,6 +24,7 @@ import org.springframework.util.Assert; import java.util.ArrayList; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -204,4 +205,15 @@ static List mergeToolCallbacks(List runtimeT return mergedToolCallbacks; } + static Map mergeToolContext(Map runtimeToolContext, + Map defaultToolContext) { + Assert.notNull(runtimeToolContext, "runtimeToolContext cannot be null"); + Assert.noNullElements(runtimeToolContext.keySet(), "runtimeToolContext keys cannot be null"); + Assert.notNull(defaultToolContext, "defaultToolContext cannot be null"); + Assert.noNullElements(defaultToolContext.keySet(), "defaultToolContext keys cannot be null"); + var mergedToolContext = new HashMap<>(defaultToolContext); + mergedToolContext.putAll(runtimeToolContext); + return mergedToolContext; + } + } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java b/spring-ai-core/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java index c3f92df2580..134151ab0b9 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/model/tool/ToolCallingChatOptionsTests.java @@ -22,6 +22,7 @@ import org.springframework.ai.tool.definition.ToolDefinition; import java.util.List; +import java.util.Map; import java.util.Set; import static org.assertj.core.api.Assertions.assertThat; @@ -141,6 +142,47 @@ void whenMergeEmptyRuntimeAndEmptyDefaultToolCallbacks() { assertThat(mergedToolCallbacks).hasSize(0); } + @Test + void whenMergeRuntimeAndDefaultToolContext() { + Map runtimeToolContext = Map.of("key1", "value1", "key2", "value2"); + Map defaultToolContext = Map.of("key1", "valueA", "key3", "value3"); + Map mergedToolContext = ToolCallingChatOptions.mergeToolContext(runtimeToolContext, + defaultToolContext); + assertThat(mergedToolContext).hasSize(3); + assertThat(mergedToolContext).containsEntry("key1", "value1") + .containsEntry("key2", "value2") + .containsEntry("key3", "value3"); + } + + @Test + void whenMergeRuntimeAndEmptyDefaultToolContext() { + Map runtimeToolContext = Map.of("key1", "value1", "key2", "value2"); + Map defaultToolContext = Map.of(); + Map mergedToolContext = ToolCallingChatOptions.mergeToolContext(runtimeToolContext, + defaultToolContext); + assertThat(mergedToolContext).hasSize(2); + assertThat(mergedToolContext).containsEntry("key1", "value1").containsEntry("key2", "value2"); + } + + @Test + void whenMergeEmptyRuntimeAndDefaultToolContext() { + Map runtimeToolContext = Map.of(); + Map defaultToolContext = Map.of("key1", "value1", "key2", "value2"); + Map mergedToolContext = ToolCallingChatOptions.mergeToolContext(runtimeToolContext, + defaultToolContext); + assertThat(mergedToolContext).hasSize(2); + assertThat(mergedToolContext).containsEntry("key1", "value1").containsEntry("key2", "value2"); + } + + @Test + void whenMergeEmptyRuntimeAndEmptyDefaultToolContext() { + Map runtimeToolContext = Map.of(); + Map defaultToolContext = Map.of(); + Map mergedToolContext = ToolCallingChatOptions.mergeToolContext(runtimeToolContext, + defaultToolContext); + assertThat(mergedToolContext).hasSize(0); + } + static class TestToolCallback implements ToolCallback { private final ToolDefinition toolDefinition;