diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java index 959a971bf85..0b83d3275a5 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatOptions.java @@ -60,6 +60,7 @@ public class AnthropicChatOptions implements ToolCallingChatOptions { private @JsonProperty("temperature") Double temperature; private @JsonProperty("top_p") Double topP; private @JsonProperty("top_k") Integer topK; + private @JsonProperty("tool_choice") AnthropicApi.ToolChoice toolChoice; private @JsonProperty("thinking") ChatCompletionRequest.ThinkingConfig thinking; @JsonIgnore @@ -117,6 +118,7 @@ public static AnthropicChatOptions fromOptions(AnthropicChatOptions fromOptions) .temperature(fromOptions.getTemperature()) .topP(fromOptions.getTopP()) .topK(fromOptions.getTopK()) + .toolChoice(fromOptions.getToolChoice()) .thinking(fromOptions.getThinking()) .toolCallbacks( fromOptions.getToolCallbacks() != null ? new ArrayList<>(fromOptions.getToolCallbacks()) : null) @@ -190,6 +192,14 @@ public void setTopK(Integer topK) { this.topK = topK; } + public AnthropicApi.ToolChoice getToolChoice() { + return this.toolChoice; + } + + public void setToolChoice(AnthropicApi.ToolChoice toolChoice) { + this.toolChoice = toolChoice; + } + public ChatCompletionRequest.ThinkingConfig getThinking() { return this.thinking; } @@ -291,7 +301,8 @@ public boolean equals(Object o) { && Objects.equals(this.metadata, that.metadata) && Objects.equals(this.stopSequences, that.stopSequences) && Objects.equals(this.temperature, that.temperature) && Objects.equals(this.topP, that.topP) - && Objects.equals(this.topK, that.topK) && Objects.equals(this.thinking, that.thinking) + && Objects.equals(this.topK, that.topK) && Objects.equals(this.toolChoice, that.toolChoice) + && Objects.equals(this.thinking, that.thinking) && Objects.equals(this.toolCallbacks, that.toolCallbacks) && Objects.equals(this.toolNames, that.toolNames) && Objects.equals(this.internalToolExecutionEnabled, that.internalToolExecutionEnabled) @@ -303,8 +314,8 @@ public boolean equals(Object o) { @Override public int hashCode() { return Objects.hash(this.model, this.maxTokens, this.metadata, this.stopSequences, this.temperature, this.topP, - this.topK, this.thinking, this.toolCallbacks, this.toolNames, this.internalToolExecutionEnabled, - this.toolContext, this.httpHeaders, this.cacheOptions); + this.topK, this.toolChoice, this.thinking, this.toolCallbacks, this.toolNames, + this.internalToolExecutionEnabled, this.toolContext, this.httpHeaders, this.cacheOptions); } public static final class Builder { @@ -351,6 +362,11 @@ public Builder topK(Integer topK) { return this; } + public Builder toolChoice(AnthropicApi.ToolChoice toolChoice) { + this.options.toolChoice = toolChoice; + return this; + } + public Builder thinking(ChatCompletionRequest.ThinkingConfig thinking) { this.options.thinking = thinking; return this; diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java index f60a7deb34f..cdd1a4fef51 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/api/AnthropicApi.java @@ -512,6 +512,8 @@ public interface StreamEvent { * return tool_use content blocks that represent the model's use of those tools. You * can then run those tools using the tool input generated by the model and then * optionally return results back to the model using tool_result content blocks. + * @param toolChoice How the model should use the provided tools. The model can use a + * specific tool, any available tool, decide by itself, or not use tools at all. * @param thinking Configuration for the model's thinking mode. When enabled, the * model can perform more in-depth reasoning before responding to a query. */ @@ -529,17 +531,19 @@ public record ChatCompletionRequest( @JsonProperty("top_p") Double topP, @JsonProperty("top_k") Integer topK, @JsonProperty("tools") List tools, + @JsonProperty("tool_choice") ToolChoice toolChoice, @JsonProperty("thinking") ThinkingConfig thinking) { // @formatter:on public ChatCompletionRequest(String model, List messages, Object system, Integer maxTokens, Double temperature, Boolean stream) { - this(model, messages, system, maxTokens, null, null, stream, temperature, null, null, null, null); + this(model, messages, system, maxTokens, null, null, stream, temperature, null, null, null, null, null); } public ChatCompletionRequest(String model, List messages, Object system, Integer maxTokens, List stopSequences, Double temperature, Boolean stream) { - this(model, messages, system, maxTokens, null, stopSequences, stream, temperature, null, null, null, null); + this(model, messages, system, maxTokens, null, stopSequences, stream, temperature, null, null, null, null, + null); } public static ChatCompletionRequestBuilder builder() { @@ -613,6 +617,8 @@ public static final class ChatCompletionRequestBuilder { private List tools; + private ToolChoice toolChoice; + private ChatCompletionRequest.ThinkingConfig thinking; private ChatCompletionRequestBuilder() { @@ -630,6 +636,7 @@ private ChatCompletionRequestBuilder(ChatCompletionRequest request) { this.topP = request.topP; this.topK = request.topK; this.tools = request.tools; + this.toolChoice = request.toolChoice; this.thinking = request.thinking; } @@ -693,6 +700,11 @@ public ChatCompletionRequestBuilder tools(List tools) { return this; } + public ChatCompletionRequestBuilder toolChoice(ToolChoice toolChoice) { + this.toolChoice = toolChoice; + return this; + } + public ChatCompletionRequestBuilder thinking(ChatCompletionRequest.ThinkingConfig thinking) { this.thinking = thinking; return this; @@ -705,7 +717,8 @@ public ChatCompletionRequestBuilder thinking(ThinkingType type, Integer budgetTo public ChatCompletionRequest build() { return new ChatCompletionRequest(this.model, this.messages, this.system, this.maxTokens, this.metadata, - this.stopSequences, this.stream, this.temperature, this.topP, this.topK, this.tools, this.thinking); + this.stopSequences, this.stream, this.temperature, this.topP, this.topK, this.tools, + this.toolChoice, this.thinking); } } @@ -1135,6 +1148,126 @@ public Tool(String name, String description, Map inputSchema) { } + /** + * Base interface for tool choice options. + */ + @JsonTypeInfo(use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.EXISTING_PROPERTY, property = "type", + visible = true) + @JsonSubTypes({ @JsonSubTypes.Type(value = ToolChoiceAuto.class, name = "auto"), + @JsonSubTypes.Type(value = ToolChoiceAny.class, name = "any"), + @JsonSubTypes.Type(value = ToolChoiceTool.class, name = "tool"), + @JsonSubTypes.Type(value = ToolChoiceNone.class, name = "none") }) + public interface ToolChoice { + + @JsonProperty("type") + String type(); + + } + + /** + * Auto tool choice - the model will automatically decide whether to use tools. + * + * @param type The type of tool choice, always "auto". + * @param disableParallelToolUse Whether to disable parallel tool use. Defaults to + * false. If set to true, the model will output at most one tool use. + */ + @JsonInclude(Include.NON_NULL) + public record ToolChoiceAuto(@JsonProperty("type") String type, + @JsonProperty("disable_parallel_tool_use") Boolean disableParallelToolUse) implements ToolChoice { + + /** + * Create an auto tool choice with default settings. + */ + public ToolChoiceAuto() { + this("auto", null); + } + + /** + * Create an auto tool choice with specific parallel tool use setting. + * @param disableParallelToolUse Whether to disable parallel tool use. + */ + public ToolChoiceAuto(Boolean disableParallelToolUse) { + this("auto", disableParallelToolUse); + } + + } + + /** + * Any tool choice - the model will use any available tools. + * + * @param type The type of tool choice, always "any". + * @param disableParallelToolUse Whether to disable parallel tool use. Defaults to + * false. If set to true, the model will output exactly one tool use. + */ + @JsonInclude(Include.NON_NULL) + public record ToolChoiceAny(@JsonProperty("type") String type, + @JsonProperty("disable_parallel_tool_use") Boolean disableParallelToolUse) implements ToolChoice { + + /** + * Create an any tool choice with default settings. + */ + public ToolChoiceAny() { + this("any", null); + } + + /** + * Create an any tool choice with specific parallel tool use setting. + * @param disableParallelToolUse Whether to disable parallel tool use. + */ + public ToolChoiceAny(Boolean disableParallelToolUse) { + this("any", disableParallelToolUse); + } + + } + + /** + * Tool choice - the model will use the specified tool. + * + * @param type The type of tool choice, always "tool". + * @param name The name of the tool to use. + * @param disableParallelToolUse Whether to disable parallel tool use. Defaults to + * false. If set to true, the model will output exactly one tool use. + */ + @JsonInclude(Include.NON_NULL) + public record ToolChoiceTool(@JsonProperty("type") String type, @JsonProperty("name") String name, + @JsonProperty("disable_parallel_tool_use") Boolean disableParallelToolUse) implements ToolChoice { + + /** + * Create a tool choice for a specific tool. + * @param name The name of the tool to use. + */ + public ToolChoiceTool(String name) { + this("tool", name, null); + } + + /** + * Create a tool choice for a specific tool with parallel tool use setting. + * @param name The name of the tool to use. + * @param disableParallelToolUse Whether to disable parallel tool use. + */ + public ToolChoiceTool(String name, Boolean disableParallelToolUse) { + this("tool", name, disableParallelToolUse); + } + + } + + /** + * None tool choice - the model will not be allowed to use tools. + * + * @param type The type of tool choice, always "none". + */ + @JsonInclude(Include.NON_NULL) + public record ToolChoiceNone(@JsonProperty("type") String type) implements ToolChoice { + + /** + * Create a none tool choice. + */ + public ToolChoiceNone() { + this("none"); + } + + } + // CB START EVENT /** diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java index 6570d5ee6a6..ae14ad3a476 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatModelIT.java @@ -491,6 +491,102 @@ void testToolUseContentBlock() { } } + @Test + void testToolChoiceAny() { + // A user question that would not typically result in a tool request + UserMessage userMessage = new UserMessage("Say hi"); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = AnthropicChatOptions.builder() + .model(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName()) + .toolChoice(new AnthropicApi.ToolChoiceAny()) + .internalToolExecutionEnabled(false) + .toolCallbacks(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .description( + "Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.") + .inputType(MockWeatherService.Request.class) + .build()) + .build(); + + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); + + logger.info("Response: {}", response); + assertThat(response.getResults()).isNotNull(); + // When tool choice is "any", the model MUST use at least one tool + boolean hasToolCalls = response.getResults() + .stream() + .anyMatch(generation -> !generation.getOutput().getToolCalls().isEmpty()); + assertThat(hasToolCalls).isTrue(); + } + + @Test + void testToolChoiceTool() { + UserMessage userMessage = new UserMessage( + "What's the weather like in San Francisco? Return the result in Celsius."); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = AnthropicChatOptions.builder() + .model(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName()) + .toolChoice(new AnthropicApi.ToolChoiceTool("getFunResponse", true)) + .internalToolExecutionEnabled(false) + .toolCallbacks(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .description( + "Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.") + .inputType(MockWeatherService.Request.class) + .build(), + // Based on the user's question the model should want to call + // getCurrentWeather + // however we're going to force getFunResponse + FunctionToolCallback.builder("getFunResponse", new MockWeatherService()) + .description("Get a fun response") + .inputType(MockWeatherService.Request.class) + .build()) + .build(); + + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); + + logger.info("Response: {}", response); + assertThat(response.getResults()).isNotNull(); + // When tool choice is a specific tool, the model MUST use that specific tool + List allToolCalls = response.getResults() + .stream() + .flatMap(generation -> generation.getOutput().getToolCalls().stream()) + .toList(); + assertThat(allToolCalls).isNotEmpty(); + assertThat(allToolCalls).hasSize(1); + assertThat(allToolCalls.get(0).name()).isEqualTo("getFunResponse"); + } + + @Test + void testToolChoiceNone() { + UserMessage userMessage = new UserMessage("What's the weather like in San Francisco?"); + + List messages = new ArrayList<>(List.of(userMessage)); + + var promptOptions = AnthropicChatOptions.builder() + .model(AnthropicApi.ChatModel.CLAUDE_3_5_SONNET.getName()) + .toolChoice(new AnthropicApi.ToolChoiceNone()) + .toolCallbacks(FunctionToolCallback.builder("getCurrentWeather", new MockWeatherService()) + .description( + "Get the weather in location. Return temperature in 36°F or 36°C format. Use multi-turn if needed.") + .inputType(MockWeatherService.Request.class) + .build()) + .build(); + + ChatResponse response = this.chatModel.call(new Prompt(messages, promptOptions)); + + logger.info("Response: {}", response); + assertThat(response.getResults()).isNotNull(); + // When tool choice is "none", the model MUST NOT use any tools + List allToolCalls = response.getResults() + .stream() + .flatMap(generation -> generation.getOutput().getToolCalls().stream()) + .toList(); + assertThat(allToolCalls).isEmpty(); + } + record ActorsFilmsRecord(String actor, List movies) { } diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/ChatCompletionRequestTests.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/ChatCompletionRequestTests.java index 53bb771319a..75d6edc1379 100644 --- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/ChatCompletionRequestTests.java +++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/ChatCompletionRequestTests.java @@ -60,4 +60,69 @@ public void createRequestWithChatOptions() { assertThat(request.temperature()).isEqualTo(99.9); } + @Test + public void createRequestWithToolChoice() { + + var client = AnthropicChatModel.builder() + .anthropicApi(AnthropicApi.builder().apiKey("TEST").build()) + .defaultOptions(AnthropicChatOptions.builder().model("DEFAULT_MODEL").build()) + .build(); + + // Test with ToolChoiceAuto + var autoToolChoice = new AnthropicApi.ToolChoiceAuto(); + var prompt = client.buildRequestPrompt( + new Prompt("Test message content", AnthropicChatOptions.builder().toolChoice(autoToolChoice).build())); + + var request = client.createRequest(prompt, false); + + assertThat(request.toolChoice()).isNotNull(); + assertThat(request.toolChoice()).isInstanceOf(AnthropicApi.ToolChoiceAuto.class); + assertThat(request.toolChoice().type()).isEqualTo("auto"); + + // Test with ToolChoiceAny + var anyToolChoice = new AnthropicApi.ToolChoiceAny(); + prompt = client.buildRequestPrompt( + new Prompt("Test message content", AnthropicChatOptions.builder().toolChoice(anyToolChoice).build())); + + request = client.createRequest(prompt, false); + + assertThat(request.toolChoice()).isNotNull(); + assertThat(request.toolChoice()).isInstanceOf(AnthropicApi.ToolChoiceAny.class); + assertThat(request.toolChoice().type()).isEqualTo("any"); + + // Test with ToolChoiceTool + var specificToolChoice = new AnthropicApi.ToolChoiceTool("get_weather"); + prompt = client.buildRequestPrompt(new Prompt("Test message content", + AnthropicChatOptions.builder().toolChoice(specificToolChoice).build())); + + request = client.createRequest(prompt, false); + + assertThat(request.toolChoice()).isNotNull(); + assertThat(request.toolChoice()).isInstanceOf(AnthropicApi.ToolChoiceTool.class); + assertThat(request.toolChoice().type()).isEqualTo("tool"); + assertThat(((AnthropicApi.ToolChoiceTool) request.toolChoice()).name()).isEqualTo("get_weather"); + + // Test with ToolChoiceNone + var noneToolChoice = new AnthropicApi.ToolChoiceNone(); + prompt = client.buildRequestPrompt( + new Prompt("Test message content", AnthropicChatOptions.builder().toolChoice(noneToolChoice).build())); + + request = client.createRequest(prompt, false); + + assertThat(request.toolChoice()).isNotNull(); + assertThat(request.toolChoice()).isInstanceOf(AnthropicApi.ToolChoiceNone.class); + assertThat(request.toolChoice().type()).isEqualTo("none"); + + // Test with disableParallelToolUse + var autoWithDisabledParallel = new AnthropicApi.ToolChoiceAuto(true); + prompt = client.buildRequestPrompt(new Prompt("Test message content", + AnthropicChatOptions.builder().toolChoice(autoWithDisabledParallel).build())); + + request = client.createRequest(prompt, false); + + assertThat(request.toolChoice()).isNotNull(); + assertThat(request.toolChoice()).isInstanceOf(AnthropicApi.ToolChoiceAuto.class); + assertThat(((AnthropicApi.ToolChoiceAuto) request.toolChoice()).disableParallelToolUse()).isTrue(); + } + } diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/anthropic-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/anthropic-chat.adoc index 44d9146e8ee..299bc0d534b 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/anthropic-chat.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/anthropic-chat.adoc @@ -803,6 +803,135 @@ You can register custom Java Tools with the `AnthropicChatModel` and have the An This is a powerful technique to connect the LLM capabilities with external tools and APIs. Read more about xref:api/tools.adoc[Tool Calling]. +=== Tool Choice + +The `tool_choice` parameter allows you to control how the model uses the provided tools. This feature gives you fine-grained control over tool execution behavior. + +For complete API details, see the https://docs.anthropic.com/en/api/messages#body-tool-choice[Anthropic tool_choice documentation]. + +==== Tool Choice Options + +Spring AI provides four tool choice strategies through the `AnthropicApi.ToolChoice` interface: + +* **`ToolChoiceAuto`** (default): The model automatically decides whether to use tools or respond with text +* **`ToolChoiceAny`**: The model must use at least one of the available tools +* **`ToolChoiceTool`**: The model must use a specific tool by name +* **`ToolChoiceNone`**: The model cannot use any tools + +==== Disabling Parallel Tool Use + +All tool choice options (except `ToolChoiceNone`) support a `disableParallelToolUse` parameter. When set to `true`, the model will output at most one tool use. + +==== Usage Examples + +===== Auto Mode (Default Behavior) + +Let the model decide whether to use tools: + +[source,java] +---- +ChatResponse response = chatModel.call( + new Prompt( + "What's the weather in San Francisco?", + AnthropicChatOptions.builder() + .toolChoice(new AnthropicApi.ToolChoiceAuto()) + .toolCallbacks(weatherToolCallback) + .build() + ) +); +---- + +===== Force Tool Use (Any) + +Require the model to use at least one tool: + +[source,java] +---- +ChatResponse response = chatModel.call( + new Prompt( + "What's the weather?", + AnthropicChatOptions.builder() + .toolChoice(new AnthropicApi.ToolChoiceAny()) + .toolCallbacks(weatherToolCallback, calculatorToolCallback) + .build() + ) +); +---- + +===== Force Specific Tool + +Require the model to use a specific tool by name: + +[source,java] +---- +ChatResponse response = chatModel.call( + new Prompt( + "What's the weather in San Francisco?", + AnthropicChatOptions.builder() + .toolChoice(new AnthropicApi.ToolChoiceTool("get_weather")) + .toolCallbacks(weatherToolCallback, calculatorToolCallback) + .build() + ) +); +---- + +===== Disable Tool Use + +Prevent the model from using any tools: + +[source,java] +---- +ChatResponse response = chatModel.call( + new Prompt( + "What's the weather in San Francisco?", + AnthropicChatOptions.builder() + .toolChoice(new AnthropicApi.ToolChoiceNone()) + .toolCallbacks(weatherToolCallback) + .build() + ) +); +---- + +===== Disable Parallel Tool Use + +Force the model to use only one tool at a time: + +[source,java] +---- +ChatResponse response = chatModel.call( + new Prompt( + "What's the weather in San Francisco and what's 2+2?", + AnthropicChatOptions.builder() + .toolChoice(new AnthropicApi.ToolChoiceAuto(true)) // disableParallelToolUse = true + .toolCallbacks(weatherToolCallback, calculatorToolCallback) + .build() + ) +); +---- + +==== Using ChatClient API + +You can also use tool choice with the fluent ChatClient API: + +[source,java] +---- +String response = ChatClient.create(chatModel) + .prompt() + .user("What's the weather in San Francisco?") + .options(AnthropicChatOptions.builder() + .toolChoice(new AnthropicApi.ToolChoiceTool("get_weather")) + .build()) + .call() + .content(); +---- + +==== Use Cases + +* **Validation**: Use `ToolChoiceTool` to ensure a specific tool is called for critical operations +* **Efficiency**: Use `ToolChoiceAny` when you know a tool must be used to avoid unnecessary text generation +* **Control**: Use `ToolChoiceNone` to temporarily disable tool access while keeping tool definitions registered +* **Sequential Processing**: Use `disableParallelToolUse` to force sequential tool execution for dependent operations + == Multimodal Multimodality refers to a model's ability to simultaneously understand and process information from various sources, including text, pdf, images, data formats.