From faf0c8b565c3e09a1809c21ca9c24070d87772e2 Mon Sep 17 00:00:00 2001 From: SenreySong <25841017+senreysong@users.noreply.github.com> Date: Thu, 4 Sep 2025 15:40:21 +0800 Subject: [PATCH 1/5] Add support for extraBody and reasoningContent - OpenAiChatOptions.extraBody: Support for custom parameter extension - ChatCompletionMessage.reasoningContent: Support for reasoning content field Signed-off-by: SenreySong <25841017+SenreySong@users.noreply.github.com> --- .../ai/openai/OpenAiChatModel.java | 224 +++++---- .../ai/openai/OpenAiChatOptions.java | 466 +++++++++--------- .../ai/openai/api/OpenAiApi.java | 283 ++++++++--- .../OpenAiStreamFunctionCallingHelper.java | 7 +- .../ai/openai/api/OpenAiApiBuilderTests.java | 11 +- .../ai/openai/api/OpenAiApiIT.java | 16 +- ...OpenAiStreamFunctionCallingHelperTest.java | 14 +- .../api/tool/OpenAiApiToolFunctionCallIT.java | 10 +- .../ai/openai/chat/OpenAiRetryTests.java | 16 +- .../OpenAiStreamingFinishReasonTests.java | 21 +- 10 files changed, 633 insertions(+), 435 deletions(-) diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index cb0fed3e549..038d2e1526f 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -16,6 +16,9 @@ package org.springframework.ai.openai; +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import java.util.ArrayList; import java.util.Base64; import java.util.HashMap; @@ -23,16 +26,8 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; - -import io.micrometer.observation.Observation; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.scheduler.Schedulers; - import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.ToolResponseMessage; @@ -86,6 +81,9 @@ import org.springframework.util.MimeTypeUtils; import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; /** * {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal OpenAI} @@ -118,24 +116,16 @@ public class OpenAiChatModel implements ChatModel { private static final ToolCallingManager DEFAULT_TOOL_CALLING_MANAGER = ToolCallingManager.builder().build(); - /** - * The default options used for the chat completion requests. - */ + /** The default options used for the chat completion requests. */ private final OpenAiChatOptions defaultOptions; - /** - * The retry template used to retry the OpenAI API calls. - */ + /** The retry template used to retry the OpenAI API calls. */ private final RetryTemplate retryTemplate; - /** - * Low-level access to the OpenAI API. - */ + /** Low-level access to the OpenAI API. */ private final OpenAiApi openAiApi; - /** - * Observation registry used for instrumentation. - */ + /** Observation registry used for instrumentation. */ private final ObservationRegistry observationRegistry; private final ToolCallingManager toolCallingManager; @@ -146,9 +136,7 @@ public class OpenAiChatModel implements ChatModel { */ private final ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate; - /** - * Conventions to use for generating observations. - */ + /** Conventions to use for generating observations. */ private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions defaultOptions, ToolCallingManager toolCallingManager, @@ -195,7 +183,6 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry) .observe(() -> { - ResponseEntity completionEntity = this.retryTemplate .execute(ctx -> this.openAiApi.chatCompletionEntity(request, getAdditionalHttpHeaders(prompt))); @@ -213,17 +200,32 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons } // @formatter:off - List generations = choices.stream().map(choice -> { - Map metadata = Map.of( - "id", chatCompletion.id() != null ? chatCompletion.id() : "", - "role", choice.message().role() != null ? choice.message().role().name() : "", - "index", choice.index() != null ? choice.index() : 0, - "finishReason", getFinishReasonJson(choice.finishReason()), - "refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "", - "annotations", choice.message().annotations() != null ? choice.message().annotations() : List.of(Map.of())); - return buildGeneration(choice, metadata, request); - }).toList(); - // @formatter:on + List generations = + choices.stream() + .map( + choice -> { + Map metadata = + Map.of( + "id", + chatCompletion.id() != null ? chatCompletion.id() : "", + "role", + choice.message().role() != null + ? choice.message().role().name() + : "", + "index", choice.index() != null ? choice.index() : 0, + "finishReason", getFinishReasonJson(choice.finishReason()), + "refusal", + StringUtils.hasText(choice.message().refusal()) + ? choice.message().refusal() + : "", + "annotations", + choice.message().annotations() != null + ? choice.message().annotations() + : List.of(Map.of())); + return buildGeneration(choice, metadata, request); + }) + .toList(); + // @formatter:on RateLimit rateLimit = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity); @@ -238,7 +240,6 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons observationContext.setResponse(chatResponse); return chatResponse; - }); if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { @@ -306,23 +307,49 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha Flux chatResponse = completionChunks.map(this::chunkToChatCompletion) .switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> { try { - // If an id is not provided, set to "NO_ID" (for compatible APIs). + // If an id is not provided, set to "NO_ID" (for compatible + // APIs). String id = chatCompletion2.id() == null ? "NO_ID" : chatCompletion2.id(); List generations = chatCompletion2.choices().stream().map(choice -> { // @formatter:off - if (choice.message().role() != null) { - roleMap.putIfAbsent(id, choice.message().role().name()); - } - Map metadata = Map.of( - "id", id, - "role", roleMap.getOrDefault(id, ""), - "index", choice.index() != null ? choice.index() : 0, - "finishReason", getFinishReasonJson(choice.finishReason()), - "refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "", - "annotations", choice.message().annotations() != null ? choice.message().annotations() : List.of()); - return buildGeneration(choice, metadata, request); - }).toList(); - // @formatter:on + if (choice.message().role() != null) { + roleMap.putIfAbsent( + id, choice.message().role().name()); + } + Map metadata = + Map.of( + "id", + id, + "role", + roleMap.getOrDefault(id, ""), + "index", + choice.index() != null + ? choice.index() + : 0, + "finishReason", + getFinishReasonJson( + choice.finishReason()), + "refusal", + StringUtils.hasText( + choice.message().refusal()) + ? choice.message().refusal() + : "", + "annotations", + choice.message().annotations() != null + ? choice.message().annotations() + : List.of(), + "reasoningContent", + choice.message().reasoningContent() + != null + ? choice + .message() + .reasoningContent() + : ""); + return buildGeneration( + choice, metadata, request); + }) + .toList(); + // @formatter:on OpenAiApi.Usage usage = chatCompletion2.usage(); Usage currentChatResponseUsage = usage != null ? getDefaultUsage(usage) : new EmptyUsage(); Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(currentChatResponseUsage, @@ -333,7 +360,8 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha logger.error("Error processing chat completion", e); return new ChatResponse(List.of()); } - // When in stream mode and enabled to include the usage, the OpenAI + // When in stream mode and enabled to include the usage, the + // OpenAI // Chat completion response would have the usage set only in its // final response. Hence, the following overlapping buffer is // created to store both the current and the subsequent response @@ -362,43 +390,54 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha }); // @formatter:off - Flux flux = chatResponse.flatMap(response -> { - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { - // FIXME: bounded elastic needs to be used since tool calling - // is currently only synchronous - return Flux.deferContextual(ctx -> { - ToolExecutionResult toolExecutionResult; - try { - ToolCallReactiveContextHolder.setContext(ctx); - toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); - } - finally { - ToolCallReactiveContextHolder.clearContext(); - } - if (toolExecutionResult.returnDirect()) { - // Return tool execution result directly to the client. - return Flux.just(ChatResponse.builder().from(response) - .generations(ToolExecutionResult.buildGenerations(toolExecutionResult)) - .build()); - } - else { - // Send the tool execution result back to the model. - return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), - response); - } - }).subscribeOn(Schedulers.boundedElastic()); - } - else { - return Flux.just(response); - } - }) - .doOnError(observation::error) - .doFinally(s -> observation.stop()) - .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); - // @formatter:on + Flux flux = + chatResponse + .flatMap( + response -> { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired( + prompt.getOptions(), response)) { + // FIXME: bounded elastic needs to be used since tool calling + // is currently only synchronous + return Flux.deferContextual( + ctx -> { + ToolExecutionResult toolExecutionResult; + try { + ToolCallReactiveContextHolder.setContext(ctx); + toolExecutionResult = + this.toolCallingManager.executeToolCalls( + prompt, response); + } finally { + ToolCallReactiveContextHolder.clearContext(); + } + if (toolExecutionResult.returnDirect()) { + // Return tool execution result directly to the client. + return Flux.just( + ChatResponse.builder() + .from(response) + .generations( + ToolExecutionResult.buildGenerations( + toolExecutionResult)) + .build()); + } else { + // Send the tool execution result back to the model. + return this.internalStream( + new Prompt( + toolExecutionResult.conversationHistory(), + prompt.getOptions()), + response); + } + }) + .subscribeOn(Schedulers.boundedElastic()); + } else { + return Flux.just(response); + } + }) + .doOnError(observation::error) + .doFinally(s -> observation.stop()) + .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); + // @formatter:on return new MessageAggregator().aggregate(flux, observationContext::setResponse); - }); } @@ -563,9 +602,7 @@ private Map mergeHttpHeaders(Map runtimeHttpHead return mergedHttpHeaders; } - /** - * Accessible for testing. - */ + /** Accessible for testing. */ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { List chatCompletionMessages = prompt.getInstructions().stream().map(message -> { @@ -598,10 +635,9 @@ else if (message.getMessageType() == MessageType.ASSISTANT) { Assert.isTrue(assistantMessage.getMedia().size() == 1, "Only one media content is supported for assistant messages"); audioOutput = new AudioOutput(assistantMessage.getMedia().get(0).getId(), null, null, null); - } return List.of(new ChatCompletionMessage(assistantMessage.getText(), - ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, audioOutput, null)); + ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, audioOutput, null, null)); } else if (message.getMessageType() == MessageType.TOOL) { ToolResponseMessage toolMessage = (ToolResponseMessage) message; @@ -611,7 +647,7 @@ else if (message.getMessageType() == MessageType.TOOL) { return toolMessage.getResponses() .stream() .map(tr -> new ChatCompletionMessage(tr.responseData(), ChatCompletionMessage.Role.TOOL, tr.name(), - tr.id(), null, null, null, null)) + tr.id(), null, null, null, null, null)) .toList(); } else { @@ -715,9 +751,7 @@ public static Builder builder() { return new Builder(); } - /** - * Returns a builder pre-populated with the current configuration for mutation. - */ + /** Returns a builder pre-populated with the current configuration for mutation. */ public Builder mutate() { return new Builder(this); } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java index 14b5ba42536..d3da58b346c 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java @@ -16,6 +16,10 @@ package org.springframework.ai.openai; +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -24,14 +28,8 @@ import java.util.Map; import java.util.Objects; import java.util.Set; - -import com.fasterxml.jackson.annotation.JsonIgnore; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.annotation.JsonProperty; import org.slf4j.Logger; import org.slf4j.LoggerFactory; - import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.openai.api.OpenAiApi; @@ -60,211 +58,228 @@ public class OpenAiChatOptions implements ToolCallingChatOptions { private static final Logger logger = LoggerFactory.getLogger(OpenAiChatOptions.class); // @formatter:off - /** - * ID of the model to use. - */ - private @JsonProperty("model") String model; - /** - * Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing - * frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. - */ - private @JsonProperty("frequency_penalty") Double frequencyPenalty; - /** - * Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object - * that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. - * Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will - * vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 - * or 100 should result in a ban or exclusive selection of the relevant token. - */ - private @JsonProperty("logit_bias") Map logitBias; - /** - * Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities - * of each output token returned in the 'content' of 'message'. - */ - private @JsonProperty("logprobs") Boolean logprobs; - /** - * An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, - * each with an associated log probability. 'logprobs' must be set to 'true' if this parameter is used. - */ - private @JsonProperty("top_logprobs") Integer topLogprobs; - /** - * The maximum number of tokens to generate in the chat completion. - * The total length of input tokens and generated tokens is limited by the model's context length. - * - *

Model-specific usage:

- *
    - *
  • Use for non-reasoning models (e.g., gpt-4o, gpt-3.5-turbo)
  • - *
  • Cannot be used with reasoning models (e.g., o1, o3, o4-mini series)
  • - *
- * - *

Mutual exclusivity: This parameter cannot be used together with - * {@link #maxCompletionTokens}. Setting both will result in an API error.

- */ - private @JsonProperty("max_tokens") Integer maxTokens; - /** - * An upper bound for the number of tokens that can be generated for a completion, - * including visible output tokens and reasoning tokens. - * - *

Model-specific usage:

- *
    - *
  • Required for reasoning models (e.g., o1, o3, o4-mini series)
  • - *
  • Cannot be used with non-reasoning models (e.g., gpt-4o, gpt-3.5-turbo)
  • - *
- * - *

Mutual exclusivity: This parameter cannot be used together with - * {@link #maxTokens}. Setting both will result in an API error.

- */ - private @JsonProperty("max_completion_tokens") Integer maxCompletionTokens; - /** - * How many chat completion choices to generate for each input message. Note that you will be charged based - * on the number of generated tokens across all of the choices. Keep n as 1 to minimize costs. - */ - private @JsonProperty("n") Integer n; - - /** - * Output types that you would like the model to generate for this request. - * Most models are capable of generating text, which is the default. - * The gpt-4o-audio-preview model can also be used to generate audio. - * To request that this model generate both text and audio responses, - * you can use: ["text", "audio"]. - * Note that the audio modality is only available for the gpt-4o-audio-preview model - * and is not supported for streaming completions. - */ - private @JsonProperty("modalities") List outputModalities; - - /** - * Audio parameters for the audio generation. Required when audio output is requested with - * modalities: ["audio"] - * Note: that the audio modality is only available for the gpt-4o-audio-preview model - * and is not supported for streaming completions. - - */ - private @JsonProperty("audio") AudioParameters outputAudio; - - /** - * Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they - * appear in the text so far, increasing the model's likelihood to talk about new topics. - */ - private @JsonProperty("presence_penalty") Double presencePenalty; - /** - * An object specifying the format that the model must output. Setting to { "type": - * "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON. - */ - private @JsonProperty("response_format") ResponseFormat responseFormat; - /** - * Options for streaming response. Included in the API only if streaming-mode completion is requested. - */ - private @JsonProperty("stream_options") StreamOptions streamOptions; - /** - * This feature is in Beta. If specified, our system will make a best effort to sample - * deterministically, such that repeated requests with the same seed and parameters should return the same result. - * Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor - * changes in the backend. - */ - private @JsonProperty("seed") Integer seed; - /** - * Up to 4 sequences where the API will stop generating further tokens. - */ - private @JsonProperty("stop") List stop; - /** - * What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output - * more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend - * altering this or top_p but not both. - */ - private @JsonProperty("temperature") Double temperature; - /** - * An alternative to sampling with temperature, called nucleus sampling, where the model considers the - * results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% - * probability mass are considered. We generally recommend altering this or temperature but not both. - */ - private @JsonProperty("top_p") Double topP; - /** - * A list of tools the model may call. Currently, only functions are supported as a tool. Use this to - * provide a list of functions the model may generate JSON inputs for. - */ - private @JsonProperty("tools") List tools; - /** - * Controls which (if any) function is called by the model. none means the model will not call a - * function and instead generates a message. auto means the model can pick between generating a message or calling a - * function. Specifying a particular function via {"type: "function", "function": {"name": "my_function"}} forces - * the model to call that function. none is the default when no functions are present. auto is the default if - * functions are present. Use the {@link ToolChoiceBuilder} to create a tool choice object. - */ - private @JsonProperty("tool_choice") Object toolChoice; - /** - * A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. - */ - private @JsonProperty("user") String user; - /** - * Whether to enable parallel function calling during tool use. - * Defaults to true. - */ - private @JsonProperty("parallel_tool_calls") Boolean parallelToolCalls; - /** - * Whether to store the output of this chat completion request for use in our model distillation or evals products. - */ - private @JsonProperty("store") Boolean store; - - /** - * Developer-defined tags and values used for filtering completions in the dashboard. - */ - private @JsonProperty("metadata") Map metadata; - - /** - * Constrains effort on reasoning for reasoning models. Currently supported values are low, medium, and high. - * Reducing reasoning effort can result in faster responses and fewer tokens used on reasoning in a response. - * Optional. Defaults to medium. - * Only for 'o1' models. - */ - private @JsonProperty("reasoning_effort") String reasoningEffort; - - /** - * verbosity: string or null - * Optional - Defaults to medium - * Constrains the verbosity of the model's response. Lower values will result in more concise responses, while higher values will result in more verbose responses. - * Currently supported values are low, medium, and high. - * If specified, the model will use web search to find relevant information to answer the user's question. - */ - private @JsonProperty("verbosity") String verbosity; - - /** - * This tool searches the web for relevant results to use in a response. - */ - private @JsonProperty("web_search_options") WebSearchOptions webSearchOptions; - - /** - * Specifies the processing type used for serving the request. - */ - private @JsonProperty("service_tier") String serviceTier; - - /** - * Collection of {@link ToolCallback}s to be used for tool calling in the chat completion requests. - */ - @JsonIgnore - private List toolCallbacks = new ArrayList<>(); - - /** - * Collection of tool names to be resolved at runtime and used for tool calling in the chat completion requests. - */ - @JsonIgnore - private Set toolNames = new HashSet<>(); - - /** - * Whether to enable the tool execution lifecycle internally in ChatModel. - */ - @JsonIgnore - private Boolean internalToolExecutionEnabled; - - /** - * Optional HTTP headers to be added to the chat completion request. - */ - @JsonIgnore - private Map httpHeaders = new HashMap<>(); - - @JsonIgnore - private Map toolContext = new HashMap<>(); - - // @formatter:on + /** ID of the model to use. */ + private @JsonProperty("model") String model; + + /** + * Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing + * frequency in the text so far, decreasing the model's likelihood to repeat the same line + * verbatim. + */ + private @JsonProperty("frequency_penalty") Double frequencyPenalty; + + /** + * Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object + * that maps tokens (specified by their token ID in the tokenizer) to an associated bias value + * from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior + * to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease + * or increase likelihood of selection; values like -100 or 100 should result in a ban or + * exclusive selection of the relevant token. + */ + private @JsonProperty("logit_bias") Map logitBias; + + /** + * Whether to return log probabilities of the output tokens or not. If true, returns the log + * probabilities of each output token returned in the 'content' of 'message'. + */ + private @JsonProperty("logprobs") Boolean logprobs; + + /** + * An integer between 0 and 5 specifying the number of most likely tokens to return at each token + * position, each with an associated log probability. 'logprobs' must be set to 'true' if this + * parameter is used. + */ + private @JsonProperty("top_logprobs") Integer topLogprobs; + + /** + * The maximum number of tokens to generate in the chat completion. The total length of input + * tokens and generated tokens is limited by the model's context length. + * + *

Model-specific usage: + * + *

    + *
  • Use for non-reasoning models (e.g., gpt-4o, gpt-3.5-turbo) + *
  • Cannot be used with reasoning models (e.g., o1, o3, o4-mini series) + *
+ * + *

Mutual exclusivity: This parameter cannot be used together with {@link + * #maxCompletionTokens}. Setting both will result in an API error. + */ + private @JsonProperty("max_tokens") Integer maxTokens; + + /** + * An upper bound for the number of tokens that can be generated for a completion, including + * visible output tokens and reasoning tokens. + * + *

Model-specific usage: + * + *

    + *
  • Required for reasoning models (e.g., o1, o3, o4-mini series) + *
  • Cannot be used with non-reasoning models (e.g., gpt-4o, gpt-3.5-turbo) + *
+ * + *

Mutual exclusivity: This parameter cannot be used together with {@link + * #maxTokens}. Setting both will result in an API error. + */ + private @JsonProperty("max_completion_tokens") Integer maxCompletionTokens; + + /** + * How many chat completion choices to generate for each input message. Note that you will be + * charged based on the number of generated tokens across all of the choices. Keep n as 1 to + * minimize costs. + */ + private @JsonProperty("n") Integer n; + + /** + * Output types that you would like the model to generate for this request. Most models are + * capable of generating text, which is the default. The gpt-4o-audio-preview model can also be + * used to generate audio. To request that this model generate both text and audio responses, you + * can use: ["text", "audio"]. Note that the audio modality is only available for the + * gpt-4o-audio-preview model and is not supported for streaming completions. + */ + private @JsonProperty("modalities") List outputModalities; + + /** + * Audio parameters for the audio generation. Required when audio output is requested with + * modalities: ["audio"] Note: that the audio modality is only available for the + * gpt-4o-audio-preview model and is not supported for streaming completions. + */ + private @JsonProperty("audio") AudioParameters outputAudio; + + /** + * Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear + * in the text so far, increasing the model's likelihood to talk about new topics. + */ + private @JsonProperty("presence_penalty") Double presencePenalty; + + /** + * An object specifying the format that the model must output. Setting to { "type": "json_object" + * } enables JSON mode, which guarantees the message the model generates is valid JSON. + */ + private @JsonProperty("response_format") ResponseFormat responseFormat; + + /** + * Options for streaming response. Included in the API only if streaming-mode completion is + * requested. + */ + private @JsonProperty("stream_options") StreamOptions streamOptions; + + /** + * This feature is in Beta. If specified, our system will make a best effort to sample + * deterministically, such that repeated requests with the same seed and parameters should return + * the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint + * response parameter to monitor changes in the backend. + */ + private @JsonProperty("seed") Integer seed; + + /** Up to 4 sequences where the API will stop generating further tokens. */ + private @JsonProperty("stop") List stop; + + /** + * What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output + * more random, while lower values like 0.2 will make it more focused and deterministic. We + * generally recommend altering this or top_p but not both. + */ + private @JsonProperty("temperature") Double temperature; + + /** + * An alternative to sampling with temperature, called nucleus sampling, where the model considers + * the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising + * the top 10% probability mass are considered. We generally recommend altering this or + * temperature but not both. + */ + private @JsonProperty("top_p") Double topP; + + /** + * A list of tools the model may call. Currently, only functions are supported as a tool. Use this + * to provide a list of functions the model may generate JSON inputs for. + */ + private @JsonProperty("tools") List tools; + + /** + * Controls which (if any) function is called by the model. none means the model will not call a + * function and instead generates a message. auto means the model can pick between generating a + * message or calling a function. Specifying a particular function via {"type: "function", + * "function": {"name": "my_function"}} forces the model to call that function. none is the + * default when no functions are present. auto is the default if functions are present. Use the + * {@link ToolChoiceBuilder} to create a tool choice object. + */ + private @JsonProperty("tool_choice") Object toolChoice; + + /** + * A unique identifier representing your end-user, which can help OpenAI to monitor and detect + * abuse. + */ + private @JsonProperty("user") String user; + + /** + * Whether to enable parallel + * function calling during tool use. Defaults to true. + */ + private @JsonProperty("parallel_tool_calls") Boolean parallelToolCalls; + + /** + * Whether to store the output of this chat completion request for use in our model distillation or evals products. + */ + private @JsonProperty("store") Boolean store; + + /** + * Developer-defined tags and values used for filtering completions in the dashboard. + */ + private @JsonProperty("metadata") Map metadata; + + /** + * Constrains effort on reasoning for reasoning models. Currently supported values are low, + * medium, and high. Reducing reasoning effort can result in faster responses and fewer tokens + * used on reasoning in a response. Optional. Defaults to medium. Only for 'o1' models. + */ + private @JsonProperty("reasoning_effort") String reasoningEffort; + + /** + * verbosity: string or null Optional - Defaults to medium Constrains the verbosity of the model's + * response. Lower values will result in more concise responses, while higher values will result + * in more verbose responses. Currently supported values are low, medium, and high. If specified, + * the model will use web search to find relevant information to answer the user's question. + */ + private @JsonProperty("verbosity") String verbosity; + + /** This tool searches the web for relevant results to use in a response. */ + private @JsonProperty("web_search_options") WebSearchOptions webSearchOptions; + + /** + * Specifies the processing + * type used for serving the request. + */ + private @JsonProperty("service_tier") String serviceTier; + + /** + * Collection of {@link ToolCallback}s to be used for tool calling in the chat completion + * requests. + */ + @JsonIgnore private List toolCallbacks = new ArrayList<>(); + + /** + * Collection of tool names to be resolved at runtime and used for tool calling in the chat + * completion requests. + */ + @JsonIgnore private Set toolNames = new HashSet<>(); + + /** Whether to enable the tool execution lifecycle internally in ChatModel. */ + @JsonIgnore private Boolean internalToolExecutionEnabled; + + /** Optional HTTP headers to be added to the chat completion request. */ + @JsonIgnore private Map httpHeaders = new HashMap<>(); + + @JsonIgnore private Map toolContext = new HashMap<>(); + + private @JsonProperty("extra_body") Map extraBody; + + // @formatter:on public static Builder builder() { return new Builder(); @@ -306,6 +321,7 @@ public static OpenAiChatOptions fromOptions(OpenAiChatOptions fromOptions) { .webSearchOptions(fromOptions.getWebSearchOptions()) .verbosity(fromOptions.getVerbosity()) .serviceTier(fromOptions.getServiceTier()) + .extraBody(fromOptions.getExtraBody()) .build(); } @@ -502,6 +518,14 @@ public void setParallelToolCalls(Boolean parallelToolCalls) { this.parallelToolCalls = parallelToolCalls; } + public Map getExtraBody() { + return extraBody; + } + + public void setExtraBody(Map extraBody) { + this.extraBody = extraBody; + } + @Override @JsonIgnore public List getToolCallbacks() { @@ -722,19 +746,17 @@ public Builder topLogprobs(Integer topLogprobs) { * *

* Model-specific usage: - *

+ * *
    - *
  • Use for non-reasoning models (e.g., gpt-4o, - * gpt-3.5-turbo)
  • + *
  • Use for non-reasoning models (e.g., gpt-4o, gpt-3.5-turbo) *
  • Cannot be used with reasoning models (e.g., o1, o3, - * o4-mini series)
  • + * o4-mini series) *
* *

* Mutual exclusivity: This parameter cannot be used together * with {@link #maxCompletionTokens(Integer)}. If both are set, the last one set * will be used and the other will be cleared with a warning. - *

* @param maxTokens the maximum number of tokens to generate, or null to unset * @return this builder instance */ @@ -756,19 +778,18 @@ public Builder maxTokens(Integer maxTokens) { * *

* Model-specific usage: - *

+ * *
    *
  • Required for reasoning models (e.g., o1, o3, o4-mini - * series)
  • + * series) *
  • Cannot be used with non-reasoning models (e.g., gpt-4o, - * gpt-3.5-turbo)
  • + * gpt-3.5-turbo) *
* *

* Mutual exclusivity: This parameter cannot be used together * with {@link #maxTokens(Integer)}. If both are set, the last one set will be * used and the other will be cleared with a warning. - *

* @param maxCompletionTokens the maximum number of completion tokens to generate, * or null to unset * @return this builder instance @@ -933,6 +954,11 @@ public Builder serviceTier(OpenAiApi.ServiceTier serviceTier) { return this; } + public Builder extraBody(Map extraBody) { + this.options.extraBody = extraBody; + return this; + } + public OpenAiChatOptions build() { return this.options; } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index c08cae71054..ff49caf90b8 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -16,22 +16,18 @@ package org.springframework.ai.openai.api; -import java.util.List; -import java.util.Map; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Consumer; -import java.util.function.Predicate; -import java.util.stream.Collectors; - import com.fasterxml.jackson.annotation.JsonFormat; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Consumer; +import java.util.function.Predicate; +import java.util.stream.Collectors; import org.springframework.ai.model.ApiKey; import org.springframework.ai.model.ChatModelDescription; import org.springframework.ai.model.ModelOptionsUtils; @@ -50,6 +46,8 @@ import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; /** * Single class implementation of the @@ -1086,36 +1084,38 @@ public enum OutputModality { * @param verbosity Controls the verbosity of the model's response. */ @JsonInclude(Include.NON_NULL) - public record ChatCompletionRequest(// @formatter:off - @JsonProperty("messages") List messages, - @JsonProperty("model") String model, - @JsonProperty("store") Boolean store, - @JsonProperty("metadata") Map metadata, - @JsonProperty("frequency_penalty") Double frequencyPenalty, - @JsonProperty("logit_bias") Map logitBias, - @JsonProperty("logprobs") Boolean logprobs, - @JsonProperty("top_logprobs") Integer topLogprobs, - @JsonProperty("max_tokens") Integer maxTokens, // original field for specifying token usage. - @JsonProperty("max_completion_tokens") Integer maxCompletionTokens, // new field for gpt-o1 and other reasoning models - @JsonProperty("n") Integer n, - @JsonProperty("modalities") List outputModalities, - @JsonProperty("audio") AudioParameters audioParameters, - @JsonProperty("presence_penalty") Double presencePenalty, - @JsonProperty("response_format") ResponseFormat responseFormat, - @JsonProperty("seed") Integer seed, - @JsonProperty("service_tier") String serviceTier, - @JsonProperty("stop") List stop, - @JsonProperty("stream") Boolean stream, - @JsonProperty("stream_options") StreamOptions streamOptions, - @JsonProperty("temperature") Double temperature, - @JsonProperty("top_p") Double topP, - @JsonProperty("tools") List tools, - @JsonProperty("tool_choice") Object toolChoice, - @JsonProperty("parallel_tool_calls") Boolean parallelToolCalls, - @JsonProperty("user") String user, - @JsonProperty("reasoning_effort") String reasoningEffort, - @JsonProperty("web_search_options") WebSearchOptions webSearchOptions, - @JsonProperty("verbosity") String verbosity) { + public record ChatCompletionRequest( // @formatter:off + @JsonProperty("messages") List messages, + @JsonProperty("model") String model, + @JsonProperty("store") Boolean store, + @JsonProperty("metadata") Map metadata, + @JsonProperty("frequency_penalty") Double frequencyPenalty, + @JsonProperty("logit_bias") Map logitBias, + @JsonProperty("logprobs") Boolean logprobs, + @JsonProperty("top_logprobs") Integer topLogprobs, + @JsonProperty("max_tokens") Integer maxTokens, // original field for specifying token usage. + @JsonProperty("max_completion_tokens") + Integer maxCompletionTokens, // new field for gpt-o1 and other reasoning models + @JsonProperty("n") Integer n, + @JsonProperty("modalities") List outputModalities, + @JsonProperty("audio") AudioParameters audioParameters, + @JsonProperty("presence_penalty") Double presencePenalty, + @JsonProperty("response_format") ResponseFormat responseFormat, + @JsonProperty("seed") Integer seed, + @JsonProperty("service_tier") String serviceTier, + @JsonProperty("stop") List stop, + @JsonProperty("stream") Boolean stream, + @JsonProperty("stream_options") StreamOptions streamOptions, + @JsonProperty("temperature") Double temperature, + @JsonProperty("top_p") Double topP, + @JsonProperty("tools") List tools, + @JsonProperty("tool_choice") Object toolChoice, + @JsonProperty("parallel_tool_calls") Boolean parallelToolCalls, + @JsonProperty("user") String user, + @JsonProperty("reasoning_effort") String reasoningEffort, + @JsonProperty("web_search_options") WebSearchOptions webSearchOptions, + @JsonProperty("verbosity") String verbosity, + @JsonProperty("extra_body") Map extraBody) { /** * Shortcut constructor for a chat completion request with the given messages, model and temperature. @@ -1125,9 +1125,37 @@ public record ChatCompletionRequest(// @formatter:off * @param temperature What sampling temperature to use, between 0 and 1. */ public ChatCompletionRequest(List messages, String model, Double temperature) { - this(messages, model, null, null, null, null, null, null, null, null, null, null, null, null, null, - null, null, null, false, null, temperature, null, - null, null, null, null, null, null, null); + this( + messages, + model, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + false, + null, + temperature, + null, + null, + null, + null, + null, + null, + null, + null, + null); } /** @@ -1138,10 +1166,37 @@ public ChatCompletionRequest(List messages, String model, * @param audio Parameters for audio output. Required when audio output is requested with outputModalities: ["audio"]. */ public ChatCompletionRequest(List messages, String model, AudioParameters audio, boolean stream) { - this(messages, model, null, null, null, null, null, null, - null, null, null, List.of(OutputModality.AUDIO, OutputModality.TEXT), audio, null, null, - null, null, null, stream, null, null, null, - null, null, null, null, null, null, null); + this( + messages, + model, + null, + null, + null, + null, + null, + null, + null, + null, + null, + List.of(OutputModality.AUDIO, OutputModality.TEXT), + audio, + null, + null, + null, + null, + null, + stream, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null); } /** @@ -1154,9 +1209,37 @@ public ChatCompletionRequest(List messages, String model, * as they become available, with the stream terminated by a data: [DONE] message. */ public ChatCompletionRequest(List messages, String model, Double temperature, boolean stream) { - this(messages, model, null, null, null, null, null, null, null, null, null, - null, null, null, null, null, null, null, stream, null, temperature, null, - null, null, null, null, null, null, null); + this( + messages, + model, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + stream, + null, + temperature, + null, + null, + null, + null, + null, + null, + null, + null, + null); } /** @@ -1170,9 +1253,37 @@ public ChatCompletionRequest(List messages, String model, */ public ChatCompletionRequest(List messages, String model, List tools, Object toolChoice) { - this(messages, model, null, null, null, null, null, null, null, null, null, - null, null, null, null, null, null, null, false, null, 0.8, null, - tools, toolChoice, null, null, null, null, null); + this( + messages, + model, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + false, + null, + 0.8, + null, + tools, + toolChoice, + null, + null, + null, + null, + null, + null); } /** @@ -1183,9 +1294,10 @@ public ChatCompletionRequest(List messages, String model, * as they become available, with the stream terminated by a data: [DONE] message. */ public ChatCompletionRequest(List messages, Boolean stream) { - this(messages, null, null, null, null, null, null, null, null, null, null, - null, null, null, null, null, null, null, stream, null, null, null, - null, null, null, null, null, null, null); + this( + messages, null, null, null, null, null, null, null, null, null, null, null, null, null, + null, null, null, null, stream, null, null, null, null, null, null, null, null, null, + null, null); } /** @@ -1195,10 +1307,37 @@ public ChatCompletionRequest(List messages, Boolean strea * @return A new {@link ChatCompletionRequest} with the specified stream options. */ public ChatCompletionRequest streamOptions(StreamOptions streamOptions) { - return new ChatCompletionRequest(this.messages, this.model, this.store, this.metadata, this.frequencyPenalty, this.logitBias, this.logprobs, - this.topLogprobs, this.maxTokens, this.maxCompletionTokens, this.n, this.outputModalities, this.audioParameters, this.presencePenalty, - this.responseFormat, this.seed, this.serviceTier, this.stop, this.stream, streamOptions, this.temperature, this.topP, - this.tools, this.toolChoice, this.parallelToolCalls, this.user, this.reasoningEffort, this.webSearchOptions, this.verbosity); + return new ChatCompletionRequest( + this.messages, + this.model, + this.store, + this.metadata, + this.frequencyPenalty, + this.logitBias, + this.logprobs, + this.topLogprobs, + this.maxTokens, + this.maxCompletionTokens, + this.n, + this.outputModalities, + this.audioParameters, + this.presencePenalty, + this.responseFormat, + this.seed, + this.serviceTier, + this.stop, + this.stream, + streamOptions, + this.temperature, + this.topP, + this.tools, + this.toolChoice, + this.parallelToolCalls, + this.user, + this.reasoningEffort, + this.webSearchOptions, + this.verbosity, + this.extraBody); } /** @@ -1402,16 +1541,18 @@ public String getValue() { */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) - public record ChatCompletionMessage(// @formatter:off - @JsonProperty("content") Object rawContent, - @JsonProperty("role") Role role, - @JsonProperty("name") String name, - @JsonProperty("tool_call_id") String toolCallId, - @JsonProperty("tool_calls") @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) List toolCalls, - @JsonProperty("refusal") String refusal, - @JsonProperty("audio") AudioOutput audioOutput, - @JsonProperty("annotations") List annotations - ) { // @formatter:on + public record ChatCompletionMessage( // @formatter:off + @JsonProperty("content") Object rawContent, + @JsonProperty("role") Role role, + @JsonProperty("name") String name, + @JsonProperty("tool_call_id") String toolCallId, + @JsonProperty("tool_calls") + @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) + List toolCalls, + @JsonProperty("refusal") String refusal, + @JsonProperty("audio") AudioOutput audioOutput, + @JsonProperty("annotations") List annotations, + @JsonProperty("reasoning_content") String reasoningContent) { // @formatter:on /** * Create a chat completion message with the given content and role. All other @@ -1420,7 +1561,7 @@ public record ChatCompletionMessage(// @formatter:off * @param role The role of the author of this message. */ public ChatCompletionMessage(Object content, Role role) { - this(content, role, null, null, null, null, null, null); + this(content, role, null, null, null, null, null, null, null); } /** diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java index d8fcb056f1f..0628acb4ab6 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java @@ -18,7 +18,6 @@ import java.util.ArrayList; import java.util.List; - import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion.Choice; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk; @@ -36,6 +35,7 @@ /** * Helper class to support Streaming function calling. * + *

* It can merge the streamed ChatCompletionChunk in case of function calling message. * * @author Christian Tzolov @@ -100,6 +100,8 @@ private ChunkChoice merge(ChunkChoice previous, ChunkChoice current) { private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompletionMessage current) { String content = (current.content() != null ? current.content() : "" + ((previous.content() != null) ? previous.content() : "")); + String reasoningContent = (current.reasoningContent() != null ? current.reasoningContent() + : "" + ((previous.reasoningContent() != null) ? previous.reasoningContent() : "")); Role role = (current.role() != null ? current.role() : previous.role()); role = (role != null ? role : Role.ASSISTANT); // default to ASSISTANT (if null String name = (current.name() != null ? current.name() : previous.name()); @@ -138,7 +140,8 @@ private ChatCompletionMessage merge(ChatCompletionMessage previous, ChatCompleti toolCalls.add(lastPreviousTooCall); } } - return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls, refusal, audioOutput, annotations); + return new ChatCompletionMessage(content, role, name, toolCallId, toolCalls, refusal, audioOutput, annotations, + reasoningContent); } private ToolCall merge(ToolCall previous, ToolCall current) { diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java index 72329d3aa88..57281336ec4 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java @@ -16,12 +16,15 @@ package org.springframework.ai.openai.api; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + import java.io.IOException; import java.util.LinkedList; import java.util.List; import java.util.Objects; import java.util.Queue; - import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.RecordedRequest; @@ -30,7 +33,6 @@ import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.opentest4j.AssertionFailedError; - import org.springframework.ai.model.ApiKey; import org.springframework.ai.model.SimpleApiKey; import org.springframework.http.HttpHeaders; @@ -43,10 +45,6 @@ import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.mock; - public class OpenAiApiBuilderTests { private static final ApiKey TEST_API_KEY = new SimpleApiKey("test-api-key"); @@ -296,6 +294,7 @@ void dynamicApiKeyWebClient() throws InterruptedException { "role": "assistant", "content": "Hello world" }, + "reasoning_content": "test", "finish_reason": "stop" } ], diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java index d050a621034..12c3bba3924 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java @@ -16,17 +16,17 @@ package org.springframework.ai.openai.api; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + import java.io.IOException; import java.util.Base64; import java.util.List; - import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.EnumSource; -import reactor.core.publisher.Flux; - import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage; @@ -36,9 +36,7 @@ import org.springframework.ai.openai.api.OpenAiApi.EmbeddingList; import org.springframework.core.io.ClassPathResource; import org.springframework.http.ResponseEntity; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; +import reactor.core.publisher.Flux; /** * @author Christian Tzolov @@ -77,7 +75,7 @@ void validateReasoningTokens() { "If a train travels 100 miles in 2 hours, what is its average speed?", ChatCompletionMessage.Role.USER); ChatCompletionRequest request = new ChatCompletionRequest(List.of(userMessage), "o1", null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, false, null, null, null, null, - null, null, null, "low", null, null); + null, null, null, "low", null, null, null); ResponseEntity response = this.openAiApi.chatCompletionEntity(request); assertThat(response).isNotNull(); @@ -180,7 +178,7 @@ void chatCompletionEntityWithNewModelsAndLowVerbosity(OpenAiApi.ChatModel modelN ChatCompletionRequest request = new ChatCompletionRequest(List.of(chatCompletionMessage), // messages modelName.getValue(), null, null, null, null, null, null, null, null, null, null, null, null, null, - null, null, null, false, null, 1.0, null, null, null, null, null, null, null, "low"); + null, null, null, false, null, 1.0, null, null, null, null, null, null, null, "low", null); ResponseEntity response = this.openAiApi.chatCompletionEntity(request); @@ -227,7 +225,7 @@ void chatCompletionEntityWithServiceTier(OpenAiApi.ServiceTier serviceTier) { ChatCompletionRequest request = new ChatCompletionRequest(List.of(chatCompletionMessage), // messages OpenAiApi.ChatModel.GPT_4_O.value, null, null, null, null, null, null, null, null, null, null, null, null, null, null, serviceTier.getValue(), null, false, null, 1.0, null, null, null, null, null, null, - null, null); + null, null, null); ResponseEntity response = this.openAiApi.chatCompletionEntity(request); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelperTest.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelperTest.java index 23fcf704fdb..7256aae459a 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelperTest.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelperTest.java @@ -16,16 +16,15 @@ package org.springframework.ai.openai.api; +import static org.assertj.core.api.Assertions.assertThat; + import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.function.Consumer; - import org.junit.jupiter.api.Test; import org.mockito.Mockito; -import static org.assertj.core.api.Assertions.assertThat; - /** * Unit tests for {@link OpenAiStreamFunctionCallingHelper} * @@ -87,8 +86,8 @@ public void isStreamingToolFunctionCall_whenChatCompletionChunkFirstChoiceDeltaT // Test for null. assertion.accept(new OpenAiApi.ChatCompletionMessage(null, null)); // Test for empty. - assertion.accept( - new OpenAiApi.ChatCompletionMessage(null, null, null, null, Collections.emptyList(), null, null, null)); + assertion.accept(new OpenAiApi.ChatCompletionMessage(null, null, null, null, Collections.emptyList(), null, + null, null, null)); } @Test @@ -101,7 +100,7 @@ public void isStreamingToolFunctionCall_whenChatCompletionChunkFirstChoiceDeltaT }; assertion.accept(new OpenAiApi.ChatCompletionMessage(null, null, null, null, List.of(Mockito.mock(org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ToolCall.class)), - null, null, null)); + null, null, null, null)); } @Test @@ -190,7 +189,8 @@ public void isStreamingToolFunctionCallReturnsFalseForNullOrEmptyChunks() { @Test public void isStreamingToolFunctionCall_returnsTrueForValidToolCalls() { var toolCall = Mockito.mock(OpenAiApi.ChatCompletionMessage.ToolCall.class); - var delta = new OpenAiApi.ChatCompletionMessage(null, null, null, null, List.of(toolCall), null, null, null); + var delta = new OpenAiApi.ChatCompletionMessage(null, null, null, null, List.of(toolCall), null, null, null, + null); var choice = new OpenAiApi.ChatCompletionChunk.ChunkChoice(null, null, delta, null); var chunk = new OpenAiApi.ChatCompletionChunk(null, List.of(choice), null, null, null, null, null, null); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java index e655de46421..e4ddf7b42aa 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java @@ -16,16 +16,16 @@ package org.springframework.ai.openai.api.tool; -import java.util.ArrayList; -import java.util.List; +import static org.assertj.core.api.Assertions.assertThat; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; +import java.util.ArrayList; +import java.util.List; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; - import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; @@ -36,8 +36,6 @@ import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ToolChoiceBuilder; import org.springframework.http.ResponseEntity; -import static org.assertj.core.api.Assertions.assertThat; - /** * Based on the OpenAI Function Calling tutorial: * https://platform.openai.com/docs/guides/function-calling/parallel-function-calling @@ -130,7 +128,7 @@ public void toolFunctionCall() { // extend conversation with function response. messages.add(new ChatCompletionMessage("" + weatherResponse.temp() + weatherRequest.unit(), Role.TOOL, - functionName, toolCall.id(), null, null, null, null)); + functionName, toolCall.id(), null, null, null, null, null)); } } diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java index e19e82640b2..766f6a101b9 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java @@ -16,9 +16,14 @@ package org.springframework.ai.openai.chat; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.BDDMockito.given; + import java.util.List; import java.util.Optional; - import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; @@ -27,8 +32,6 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import reactor.core.publisher.Flux; - import org.springframework.ai.audio.transcription.AudioTranscriptionPrompt; import org.springframework.ai.audio.transcription.AudioTranscriptionResponse; import org.springframework.ai.chat.prompt.Prompt; @@ -70,12 +73,7 @@ import org.springframework.retry.RetryContext; import org.springframework.retry.RetryListener; import org.springframework.retry.support.RetryTemplate; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.BDDMockito.given; +import reactor.core.publisher.Flux; /** * @author Christian Tzolov diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiStreamingFinishReasonTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiStreamingFinishReasonTests.java index 3dc59444e82..a35fce716ca 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiStreamingFinishReasonTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiStreamingFinishReasonTests.java @@ -16,16 +16,18 @@ package org.springframework.ai.openai.chat; -import java.util.List; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.BDDMockito.given; import com.fasterxml.jackson.core.JsonProcessingException; import io.micrometer.observation.ObservationRegistry; +import java.util.List; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; -import reactor.core.publisher.Flux; - import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; @@ -41,11 +43,7 @@ import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; import org.springframework.ai.retry.RetryUtils; import org.springframework.retry.support.RetryTemplate; - -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.BDDMockito.given; +import reactor.core.publisher.Flux; /** * Tests for OpenAI streaming responses with various finish_reason scenarios, particularly @@ -124,7 +122,8 @@ void testJsonDeserializationWithEmptyStringFinishReason() throws JsonProcessingE "index": 0, "delta": { "role": "assistant", - "content": "" + "content": "", + "reasoning_content": "" }, "finish_reason": "" }] @@ -161,7 +160,8 @@ void testJsonDeserializationWithNullFinishReason() throws JsonProcessingExceptio "index": 0, "delta": { "role": "assistant", - "content": "Hello" + "content": "Hello", + "reasoning_content": "test" }, "finish_reason": null }] @@ -176,6 +176,7 @@ void testJsonDeserializationWithNullFinishReason() throws JsonProcessingExceptio var choice = chunk.choices().get(0); assertThat(choice.finishReason()).isNull(); assertThat(choice.delta().content()).isEqualTo("Hello"); + assertThat(choice.delta().reasoningContent()).isEqualTo("test"); } @Test From bbcb96854b5ae2ff95c9607f6d92c5436e4f7a3e Mon Sep 17 00:00:00 2001 From: SenreySong <25841017+SenreySong@users.noreply.github.com> Date: Thu, 4 Sep 2025 17:02:38 +0800 Subject: [PATCH 2/5] Fix equals and hashCode Signed-off-by: SenreySong <25841017+SenreySong@users.noreply.github.com> --- .../org/springframework/ai/openai/OpenAiChatOptions.java | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java index d3da58b346c..0331b8d20eb 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java @@ -654,7 +654,8 @@ public int hashCode() { this.streamOptions, this.seed, this.stop, this.temperature, this.topP, this.tools, this.toolChoice, this.user, this.parallelToolCalls, this.toolCallbacks, this.toolNames, this.httpHeaders, this.internalToolExecutionEnabled, this.toolContext, this.outputModalities, this.outputAudio, - this.store, this.metadata, this.reasoningEffort, this.webSearchOptions, this.serviceTier); + this.store, this.metadata, this.reasoningEffort, this.webSearchOptions, this.serviceTier, + this.extraBody); } @Override @@ -689,7 +690,8 @@ public boolean equals(Object o) { && Objects.equals(this.reasoningEffort, other.reasoningEffort) && Objects.equals(this.webSearchOptions, other.webSearchOptions) && Objects.equals(this.verbosity, other.verbosity) - && Objects.equals(this.serviceTier, other.serviceTier); + && Objects.equals(this.serviceTier, other.serviceTier) + && Objects.equals(this.extraBody, other.extraBody); } @Override From bbf0ad23dac74c5b87e7cde340dd6a0df88e789b Mon Sep 17 00:00:00 2001 From: SenreySong <25841017+SenreySong@users.noreply.github.com> Date: Mon, 8 Sep 2025 18:46:36 +0800 Subject: [PATCH 3/5] Fix dynamic parameter adaptation for extra_body - Updated streaming chat response logic to use dynamic request body creation. Signed-off-by: SenreySong <25841017+SenreySong@users.noreply.github.com> --- .../ai/openai/api/OpenAiApi.java | 58 +++++++++++++++---- 1 file changed, 48 insertions(+), 10 deletions(-) diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index 55868753045..0dea0e6e291 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -22,13 +22,16 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; +import com.fasterxml.jackson.databind.node.ObjectNode; import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Consumer; import java.util.function.Predicate; import java.util.stream.Collectors; -import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import org.springframework.ai.model.ApiKey; import org.springframework.ai.model.ChatModelDescription; import org.springframework.ai.model.ModelOptionsUtils; @@ -185,6 +188,7 @@ public ResponseEntity chatCompletionEntity(ChatCompletionRequest Assert.isTrue(!chatRequest.stream(), "Request must set the stream property to false."); Assert.notNull(additionalHttpHeader, "The additional HTTP headers can not be null."); + Object dynamicRequestBody = createDynamicRequestBody(chatRequest); // @formatter:off return this.restClient.post() .uri(this.completionsPath) @@ -192,7 +196,7 @@ public ResponseEntity chatCompletionEntity(ChatCompletionRequest headers.addAll(additionalHttpHeader); addDefaultHeadersIfMissing(headers); }) - .body(chatRequest) + .body(dynamicRequestBody) .retrieve() .toEntity(ChatCompletion.class); // @formatter:on @@ -208,6 +212,29 @@ public Flux chatCompletionStream(ChatCompletionRequest chat return chatCompletionStream(chatRequest, new LinkedMultiValueMap<>()); } + private Object createDynamicRequestBody(ChatCompletionRequest baseRequest) { + ObjectMapper mapper = new ObjectMapper(); + ObjectNode requestNode = mapper.valueToTree(baseRequest); + if (null == baseRequest.extraBody) { + return requestNode; + } + + // 添加额外字段 + baseRequest.extraBody().forEach((key, value) -> { + if (value instanceof Map) { + requestNode.set(key, mapper.valueToTree(value)); + } + else if (value instanceof List) { + requestNode.set(key, mapper.valueToTree(value)); + } + else { + requestNode.putPOJO(key, value); + } + }); + + return requestNode; + } + /** * Creates a streaming chat response for the given chat conversation. * @param chatRequest The chat completion request. Must have the stream property set @@ -224,14 +251,25 @@ public Flux chatCompletionStream(ChatCompletionRequest chat AtomicBoolean isInsideTool = new AtomicBoolean(false); - // @formatter:off - return this.webClient.post() - .uri(this.completionsPath) - .headers(headers -> { - headers.addAll(additionalHttpHeader); - addDefaultHeadersIfMissing(headers); - }) // @formatter:on - .body(Mono.just(chatRequest), ChatCompletionRequest.class) + ObjectMapper objectMapper = new ObjectMapper(); + try { + var s = objectMapper.writeValueAsString(chatRequest); + System.out.println("aaaaaaaa:" + s); + } + catch (JsonProcessingException e) { + throw new RuntimeException(e); + } + Object dynamicBody = createDynamicRequestBody(chatRequest); + // @formatter:off + return this.webClient + .post() + .uri(this.completionsPath) + .headers( + headers -> { + headers.addAll(additionalHttpHeader); + addDefaultHeadersIfMissing(headers); + }) // @formatter:on + .bodyValue(dynamicBody) .retrieve() .bodyToFlux(String.class) // cancels the flux stream after the "[DONE]" is received. From 94867e0f033f73ec13d87c9340e6f0c6dc32536a Mon Sep 17 00:00:00 2001 From: Senrey_Song <25841017+senreysong@users.noreply.github.com> Date: Tue, 16 Sep 2025 10:29:19 +0800 Subject: [PATCH 4/5] refactor(spring-ai-openai): fix code style Signed-off-by: SenreySong <25841017+SenreySong@users.noreply.github.com> refactor(spring-ai-openai): fix import style Signed-off-by: SenreySong <25841017+SenreySong@users.noreply.github.com> refactor(spring-ai-openai): fix code style Signed-off-by: SenreySong <25841017+SenreySong@users.noreply.github.com> refactor(spring-ai-openai): fix code style Signed-off-by: SenreySong <25841017+SenreySong@users.noreply.github.com> --- .../ai/openai/OpenAiChatModel.java | 221 ++++----- .../ai/openai/OpenAiChatOptions.java | 454 +++++++++--------- .../ai/openai/api/OpenAiApi.java | 301 +++--------- .../OpenAiStreamFunctionCallingHelper.java | 1 + .../ai/openai/api/OpenAiApiBuilderTests.java | 10 +- .../ai/openai/api/OpenAiApiIT.java | 10 +- ...OpenAiStreamFunctionCallingHelperTest.java | 9 +- .../api/tool/OpenAiApiToolFunctionCallIT.java | 8 +- .../ai/openai/chat/OpenAiRetryTests.java | 16 +- .../OpenAiStreamingFinishReasonTests.java | 14 +- 10 files changed, 436 insertions(+), 608 deletions(-) diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index 7b73e668fa5..85b06a5a4d6 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -16,9 +16,6 @@ package org.springframework.ai.openai; -import io.micrometer.observation.Observation; -import io.micrometer.observation.ObservationRegistry; -import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import java.util.ArrayList; import java.util.Base64; import java.util.HashMap; @@ -26,8 +23,16 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.stream.Collectors; + +import io.micrometer.observation.Observation; +import io.micrometer.observation.ObservationRegistry; +import io.micrometer.observation.contextpropagation.ObservationThreadLocalAccessor; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.core.scheduler.Schedulers; + import org.springframework.ai.chat.messages.AssistantMessage; import org.springframework.ai.chat.messages.MessageType; import org.springframework.ai.chat.messages.ToolResponseMessage; @@ -81,9 +86,6 @@ import org.springframework.util.MimeTypeUtils; import org.springframework.util.MultiValueMap; import org.springframework.util.StringUtils; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.core.scheduler.Schedulers; /** * {@link ChatModel} and {@link StreamingChatModel} implementation for {@literal OpenAI} @@ -116,16 +118,24 @@ public class OpenAiChatModel implements ChatModel { private static final ToolCallingManager DEFAULT_TOOL_CALLING_MANAGER = ToolCallingManager.builder().build(); - /** The default options used for the chat completion requests. */ + /** + * The default options used for the chat completion requests. + */ private final OpenAiChatOptions defaultOptions; - /** The retry template used to retry the OpenAI API calls. */ + /** + * The retry template used to retry the OpenAI API calls. + */ private final RetryTemplate retryTemplate; - /** Low-level access to the OpenAI API. */ + /** + * Low-level access to the OpenAI API. + */ private final OpenAiApi openAiApi; - /** Observation registry used for instrumentation. */ + /** + * Observation registry used for instrumentation. + */ private final ObservationRegistry observationRegistry; private final ToolCallingManager toolCallingManager; @@ -136,7 +146,9 @@ public class OpenAiChatModel implements ChatModel { */ private final ToolExecutionEligibilityPredicate toolExecutionEligibilityPredicate; - /** Conventions to use for generating observations. */ + /** + * Conventions to use for generating observations. + */ private ChatModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; public OpenAiChatModel(OpenAiApi openAiApi, OpenAiChatOptions defaultOptions, ToolCallingManager toolCallingManager, @@ -183,6 +195,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry) .observe(() -> { + ResponseEntity completionEntity = this.retryTemplate .execute(ctx -> this.openAiApi.chatCompletionEntity(request, getAdditionalHttpHeaders(prompt))); @@ -200,32 +213,17 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons } // @formatter:off - List generations = - choices.stream() - .map( - choice -> { - Map metadata = - Map.of( - "id", - chatCompletion.id() != null ? chatCompletion.id() : "", - "role", - choice.message().role() != null - ? choice.message().role().name() - : "", - "index", choice.index() != null ? choice.index() : 0, - "finishReason", getFinishReasonJson(choice.finishReason()), - "refusal", - StringUtils.hasText(choice.message().refusal()) - ? choice.message().refusal() - : "", - "annotations", - choice.message().annotations() != null - ? choice.message().annotations() - : List.of(Map.of())); - return buildGeneration(choice, metadata, request); - }) - .toList(); - // @formatter:on + List generations = choices.stream().map(choice -> { + Map metadata = Map.of( + "id", chatCompletion.id() != null ? chatCompletion.id() : "", + "role", choice.message().role() != null ? choice.message().role().name() : "", + "index", choice.index() != null ? choice.index() : 0, + "finishReason", getFinishReasonJson(choice.finishReason()), + "refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "", + "annotations", choice.message().annotations() != null ? choice.message().annotations() : List.of(Map.of())); + return buildGeneration(choice, metadata, request); + }).toList(); + // @formatter:on RateLimit rateLimit = OpenAiResponseHeaderExtractor.extractAiResponseHeaders(completionEntity); @@ -240,6 +238,7 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons observationContext.setResponse(chatResponse); return chatResponse; + }); if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { @@ -307,49 +306,24 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha Flux chatResponse = completionChunks.map(this::chunkToChatCompletion) .switchMap(chatCompletion -> Mono.just(chatCompletion).map(chatCompletion2 -> { try { - // If an id is not provided, set to "NO_ID" (for compatible - // APIs). + // If an id is not provided, set to "NO_ID" (for compatible APIs). String id = chatCompletion2.id() == null ? "NO_ID" : chatCompletion2.id(); List generations = chatCompletion2.choices().stream().map(choice -> { // @formatter:off - if (choice.message().role() != null) { - roleMap.putIfAbsent( - id, choice.message().role().name()); - } - Map metadata = - Map.of( - "id", - id, - "role", - roleMap.getOrDefault(id, ""), - "index", - choice.index() != null - ? choice.index() - : 0, - "finishReason", - getFinishReasonJson( - choice.finishReason()), - "refusal", - StringUtils.hasText( - choice.message().refusal()) - ? choice.message().refusal() - : "", - "annotations", - choice.message().annotations() != null - ? choice.message().annotations() - : List.of(), - "reasoningContent", - choice.message().reasoningContent() - != null - ? choice - .message() - .reasoningContent() - : ""); - return buildGeneration( - choice, metadata, request); - }) - .toList(); - // @formatter:on + if (choice.message().role() != null) { + roleMap.putIfAbsent(id, choice.message().role().name()); + } + Map metadata = Map.of( + "id", id, + "role", roleMap.getOrDefault(id, ""), + "index", choice.index() != null ? choice.index() : 0, + "finishReason", getFinishReasonJson(choice.finishReason()), + "refusal", StringUtils.hasText(choice.message().refusal()) ? choice.message().refusal() : "", + "annotations", choice.message().annotations() != null ? choice.message().annotations() : List.of(), + "reasoningContent", choice.message().reasoningContent() != null ? choice.message().reasoningContent() : ""); + return buildGeneration(choice, metadata, request); + }).toList(); + // @formatter:on OpenAiApi.Usage usage = chatCompletion2.usage(); Usage currentChatResponseUsage = usage != null ? getDefaultUsage(usage) : new EmptyUsage(); Usage accumulatedUsage = UsageCalculator.getCumulativeUsage(currentChatResponseUsage, @@ -360,8 +334,7 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha logger.error("Error processing chat completion", e); return new ChatResponse(List.of()); } - // When in stream mode and enabled to include the usage, the - // OpenAI + // When in stream mode and enabled to include the usage, the OpenAI // Chat completion response would have the usage set only in its // final response. Hence, the following overlapping buffer is // created to store both the current and the subsequent response @@ -390,54 +363,43 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha }); // @formatter:off - Flux flux = - chatResponse - .flatMap( - response -> { - if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired( - prompt.getOptions(), response)) { - // FIXME: bounded elastic needs to be used since tool calling - // is currently only synchronous - return Flux.deferContextual( - ctx -> { - ToolExecutionResult toolExecutionResult; - try { - ToolCallReactiveContextHolder.setContext(ctx); - toolExecutionResult = - this.toolCallingManager.executeToolCalls( - prompt, response); - } finally { - ToolCallReactiveContextHolder.clearContext(); - } - if (toolExecutionResult.returnDirect()) { - // Return tool execution result directly to the client. - return Flux.just( - ChatResponse.builder() - .from(response) - .generations( - ToolExecutionResult.buildGenerations( - toolExecutionResult)) - .build()); - } else { - // Send the tool execution result back to the model. - return this.internalStream( - new Prompt( - toolExecutionResult.conversationHistory(), - prompt.getOptions()), - response); - } - }) - .subscribeOn(Schedulers.boundedElastic()); - } else { - return Flux.just(response); - } - }) - .doOnError(observation::error) - .doFinally(s -> observation.stop()) - .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); - // @formatter:on + Flux flux = chatResponse.flatMap(response -> { + if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { + // FIXME: bounded elastic needs to be used since tool calling + // is currently only synchronous + return Flux.deferContextual(ctx -> { + ToolExecutionResult toolExecutionResult; + try { + ToolCallReactiveContextHolder.setContext(ctx); + toolExecutionResult = this.toolCallingManager.executeToolCalls(prompt, response); + } + finally { + ToolCallReactiveContextHolder.clearContext(); + } + if (toolExecutionResult.returnDirect()) { + // Return tool execution result directly to the client. + return Flux.just(ChatResponse.builder().from(response) + .generations(ToolExecutionResult.buildGenerations(toolExecutionResult)) + .build()); + } + else { + // Send the tool execution result back to the model. + return this.internalStream(new Prompt(toolExecutionResult.conversationHistory(), prompt.getOptions()), + response); + } + }).subscribeOn(Schedulers.boundedElastic()); + } + else { + return Flux.just(response); + } + }) + .doOnError(observation::error) + .doFinally(s -> observation.stop()) + .contextWrite(ctx -> ctx.put(ObservationThreadLocalAccessor.KEY, observation)); + // @formatter:on return new MessageAggregator().aggregate(flux, observationContext::setResponse); + }); } @@ -607,7 +569,9 @@ private Map mergeHttpHeaders(Map runtimeHttpHead return mergedHttpHeaders; } - /** Accessible for testing. */ + /** + * Accessible for testing. + */ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) { List chatCompletionMessages = prompt.getInstructions().stream().map(message -> { @@ -640,6 +604,7 @@ else if (message.getMessageType() == MessageType.ASSISTANT) { Assert.isTrue(assistantMessage.getMedia().size() == 1, "Only one media content is supported for assistant messages"); audioOutput = new AudioOutput(assistantMessage.getMedia().get(0).getId(), null, null, null); + } return List.of(new ChatCompletionMessage(assistantMessage.getText(), ChatCompletionMessage.Role.ASSISTANT, null, null, toolCalls, null, audioOutput, null, null)); @@ -756,7 +721,9 @@ public static Builder builder() { return new Builder(); } - /** Returns a builder pre-populated with the current configuration for mutation. */ + /** + * Returns a builder pre-populated with the current configuration for mutation. + */ public Builder mutate() { return new Builder(this); } diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java index 0331b8d20eb..bdbb308d6ef 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java @@ -16,10 +16,6 @@ package org.springframework.ai.openai; -import com.fasterxml.jackson.annotation.JsonIgnore; -import com.fasterxml.jackson.annotation.JsonInclude; -import com.fasterxml.jackson.annotation.JsonInclude.Include; -import com.fasterxml.jackson.annotation.JsonProperty; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -28,8 +24,14 @@ import java.util.Map; import java.util.Objects; import java.util.Set; + +import com.fasterxml.jackson.annotation.JsonIgnore; +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.annotation.JsonInclude.Include; +import com.fasterxml.jackson.annotation.JsonProperty; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.model.tool.ToolCallingChatOptions; import org.springframework.ai.openai.api.OpenAiApi; @@ -58,228 +60,213 @@ public class OpenAiChatOptions implements ToolCallingChatOptions { private static final Logger logger = LoggerFactory.getLogger(OpenAiChatOptions.class); // @formatter:off - /** ID of the model to use. */ - private @JsonProperty("model") String model; - - /** - * Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing - * frequency in the text so far, decreasing the model's likelihood to repeat the same line - * verbatim. - */ - private @JsonProperty("frequency_penalty") Double frequencyPenalty; - - /** - * Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object - * that maps tokens (specified by their token ID in the tokenizer) to an associated bias value - * from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior - * to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease - * or increase likelihood of selection; values like -100 or 100 should result in a ban or - * exclusive selection of the relevant token. - */ - private @JsonProperty("logit_bias") Map logitBias; - - /** - * Whether to return log probabilities of the output tokens or not. If true, returns the log - * probabilities of each output token returned in the 'content' of 'message'. - */ - private @JsonProperty("logprobs") Boolean logprobs; - - /** - * An integer between 0 and 5 specifying the number of most likely tokens to return at each token - * position, each with an associated log probability. 'logprobs' must be set to 'true' if this - * parameter is used. - */ - private @JsonProperty("top_logprobs") Integer topLogprobs; - - /** - * The maximum number of tokens to generate in the chat completion. The total length of input - * tokens and generated tokens is limited by the model's context length. - * - *

Model-specific usage: - * - *

    - *
  • Use for non-reasoning models (e.g., gpt-4o, gpt-3.5-turbo) - *
  • Cannot be used with reasoning models (e.g., o1, o3, o4-mini series) - *
- * - *

Mutual exclusivity: This parameter cannot be used together with {@link - * #maxCompletionTokens}. Setting both will result in an API error. - */ - private @JsonProperty("max_tokens") Integer maxTokens; - - /** - * An upper bound for the number of tokens that can be generated for a completion, including - * visible output tokens and reasoning tokens. - * - *

Model-specific usage: - * - *

    - *
  • Required for reasoning models (e.g., o1, o3, o4-mini series) - *
  • Cannot be used with non-reasoning models (e.g., gpt-4o, gpt-3.5-turbo) - *
- * - *

Mutual exclusivity: This parameter cannot be used together with {@link - * #maxTokens}. Setting both will result in an API error. - */ - private @JsonProperty("max_completion_tokens") Integer maxCompletionTokens; - - /** - * How many chat completion choices to generate for each input message. Note that you will be - * charged based on the number of generated tokens across all of the choices. Keep n as 1 to - * minimize costs. - */ - private @JsonProperty("n") Integer n; - - /** - * Output types that you would like the model to generate for this request. Most models are - * capable of generating text, which is the default. The gpt-4o-audio-preview model can also be - * used to generate audio. To request that this model generate both text and audio responses, you - * can use: ["text", "audio"]. Note that the audio modality is only available for the - * gpt-4o-audio-preview model and is not supported for streaming completions. - */ - private @JsonProperty("modalities") List outputModalities; - - /** - * Audio parameters for the audio generation. Required when audio output is requested with - * modalities: ["audio"] Note: that the audio modality is only available for the - * gpt-4o-audio-preview model and is not supported for streaming completions. - */ - private @JsonProperty("audio") AudioParameters outputAudio; - - /** - * Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear - * in the text so far, increasing the model's likelihood to talk about new topics. - */ - private @JsonProperty("presence_penalty") Double presencePenalty; - - /** - * An object specifying the format that the model must output. Setting to { "type": "json_object" - * } enables JSON mode, which guarantees the message the model generates is valid JSON. - */ - private @JsonProperty("response_format") ResponseFormat responseFormat; - - /** - * Options for streaming response. Included in the API only if streaming-mode completion is - * requested. - */ - private @JsonProperty("stream_options") StreamOptions streamOptions; - - /** - * This feature is in Beta. If specified, our system will make a best effort to sample - * deterministically, such that repeated requests with the same seed and parameters should return - * the same result. Determinism is not guaranteed, and you should refer to the system_fingerprint - * response parameter to monitor changes in the backend. - */ - private @JsonProperty("seed") Integer seed; - - /** Up to 4 sequences where the API will stop generating further tokens. */ - private @JsonProperty("stop") List stop; - - /** - * What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output - * more random, while lower values like 0.2 will make it more focused and deterministic. We - * generally recommend altering this or top_p but not both. - */ - private @JsonProperty("temperature") Double temperature; - - /** - * An alternative to sampling with temperature, called nucleus sampling, where the model considers - * the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising - * the top 10% probability mass are considered. We generally recommend altering this or - * temperature but not both. - */ - private @JsonProperty("top_p") Double topP; - - /** - * A list of tools the model may call. Currently, only functions are supported as a tool. Use this - * to provide a list of functions the model may generate JSON inputs for. - */ - private @JsonProperty("tools") List tools; - - /** - * Controls which (if any) function is called by the model. none means the model will not call a - * function and instead generates a message. auto means the model can pick between generating a - * message or calling a function. Specifying a particular function via {"type: "function", - * "function": {"name": "my_function"}} forces the model to call that function. none is the - * default when no functions are present. auto is the default if functions are present. Use the - * {@link ToolChoiceBuilder} to create a tool choice object. - */ - private @JsonProperty("tool_choice") Object toolChoice; - - /** - * A unique identifier representing your end-user, which can help OpenAI to monitor and detect - * abuse. - */ - private @JsonProperty("user") String user; - - /** - * Whether to enable parallel - * function calling during tool use. Defaults to true. - */ - private @JsonProperty("parallel_tool_calls") Boolean parallelToolCalls; - - /** - * Whether to store the output of this chat completion request for use in our model distillation or evals products. - */ - private @JsonProperty("store") Boolean store; - - /** - * Developer-defined tags and values used for filtering completions in the dashboard. - */ - private @JsonProperty("metadata") Map metadata; - - /** - * Constrains effort on reasoning for reasoning models. Currently supported values are low, - * medium, and high. Reducing reasoning effort can result in faster responses and fewer tokens - * used on reasoning in a response. Optional. Defaults to medium. Only for 'o1' models. - */ - private @JsonProperty("reasoning_effort") String reasoningEffort; - - /** - * verbosity: string or null Optional - Defaults to medium Constrains the verbosity of the model's - * response. Lower values will result in more concise responses, while higher values will result - * in more verbose responses. Currently supported values are low, medium, and high. If specified, - * the model will use web search to find relevant information to answer the user's question. - */ - private @JsonProperty("verbosity") String verbosity; - - /** This tool searches the web for relevant results to use in a response. */ - private @JsonProperty("web_search_options") WebSearchOptions webSearchOptions; - - /** - * Specifies the processing - * type used for serving the request. - */ - private @JsonProperty("service_tier") String serviceTier; - - /** - * Collection of {@link ToolCallback}s to be used for tool calling in the chat completion - * requests. - */ - @JsonIgnore private List toolCallbacks = new ArrayList<>(); - - /** - * Collection of tool names to be resolved at runtime and used for tool calling in the chat - * completion requests. - */ - @JsonIgnore private Set toolNames = new HashSet<>(); - - /** Whether to enable the tool execution lifecycle internally in ChatModel. */ - @JsonIgnore private Boolean internalToolExecutionEnabled; - - /** Optional HTTP headers to be added to the chat completion request. */ - @JsonIgnore private Map httpHeaders = new HashMap<>(); - - @JsonIgnore private Map toolContext = new HashMap<>(); - - private @JsonProperty("extra_body") Map extraBody; - - // @formatter:on + /** + * ID of the model to use. + */ + private @JsonProperty("model") String model; + /** + * Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing + * frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. + */ + private @JsonProperty("frequency_penalty") Double frequencyPenalty; + /** + * Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object + * that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. + * Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will + * vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 + * or 100 should result in a ban or exclusive selection of the relevant token. + */ + private @JsonProperty("logit_bias") Map logitBias; + /** + * Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities + * of each output token returned in the 'content' of 'message'. + */ + private @JsonProperty("logprobs") Boolean logprobs; + /** + * An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, + * each with an associated log probability. 'logprobs' must be set to 'true' if this parameter is used. + */ + private @JsonProperty("top_logprobs") Integer topLogprobs; + /** + * The maximum number of tokens to generate in the chat completion. + * The total length of input tokens and generated tokens is limited by the model's context length. + * + *

Model-specific usage:

+ *
    + *
  • Use for non-reasoning models (e.g., gpt-4o, gpt-3.5-turbo)
  • + *
  • Cannot be used with reasoning models (e.g., o1, o3, o4-mini series)
  • + *
+ * + *

Mutual exclusivity: This parameter cannot be used together with + * {@link #maxCompletionTokens}. Setting both will result in an API error.

+ */ + private @JsonProperty("max_tokens") Integer maxTokens; + /** + * An upper bound for the number of tokens that can be generated for a completion, + * including visible output tokens and reasoning tokens. + * + *

Model-specific usage:

+ *
    + *
  • Required for reasoning models (e.g., o1, o3, o4-mini series)
  • + *
  • Cannot be used with non-reasoning models (e.g., gpt-4o, gpt-3.5-turbo)
  • + *
+ * + *

Mutual exclusivity: This parameter cannot be used together with + * {@link #maxTokens}. Setting both will result in an API error.

+ */ + private @JsonProperty("max_completion_tokens") Integer maxCompletionTokens; + /** + * How many chat completion choices to generate for each input message. Note that you will be charged based + * on the number of generated tokens across all of the choices. Keep n as 1 to minimize costs. + */ + private @JsonProperty("n") Integer n; + + /** + * Output types that you would like the model to generate for this request. + * Most models are capable of generating text, which is the default. + * The gpt-4o-audio-preview model can also be used to generate audio. + * To request that this model generate both text and audio responses, + * you can use: ["text", "audio"]. + * Note that the audio modality is only available for the gpt-4o-audio-preview model + * and is not supported for streaming completions. + */ + private @JsonProperty("modalities") List outputModalities; + + /** + * Audio parameters for the audio generation. Required when audio output is requested with + * modalities: ["audio"] + * Note: that the audio modality is only available for the gpt-4o-audio-preview model + * and is not supported for streaming completions. + * + */ + private @JsonProperty("audio") AudioParameters outputAudio; + + /** + * Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they + * appear in the text so far, increasing the model's likelihood to talk about new topics. + */ + private @JsonProperty("presence_penalty") Double presencePenalty; + /** + * An object specifying the format that the model must output. Setting to { "type": + * "json_object" } enables JSON mode, which guarantees the message the model generates is valid JSON. + */ + private @JsonProperty("response_format") ResponseFormat responseFormat; + /** + * Options for streaming response. Included in the API only if streaming-mode completion is requested. + */ + private @JsonProperty("stream_options") StreamOptions streamOptions; + /** + * This feature is in Beta. If specified, our system will make a best effort to sample + * deterministically, such that repeated requests with the same seed and parameters should return the same result. + * Determinism is not guaranteed, and you should refer to the system_fingerprint response parameter to monitor + * changes in the backend. + */ + private @JsonProperty("seed") Integer seed; + /** + * Up to 4 sequences where the API will stop generating further tokens. + */ + private @JsonProperty("stop") List stop; + /** + * What sampling temperature to use, between 0 and 1. Higher values like 0.8 will make the output + * more random, while lower values like 0.2 will make it more focused and deterministic. We generally recommend + * altering this or top_p but not both. + */ + private @JsonProperty("temperature") Double temperature; + /** + * An alternative to sampling with temperature, called nucleus sampling, where the model considers the + * results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% + * probability mass are considered. We generally recommend altering this or temperature but not both. + */ + private @JsonProperty("top_p") Double topP; + /** + * A list of tools the model may call. Currently, only functions are supported as a tool. Use this to + * provide a list of functions the model may generate JSON inputs for. + */ + private @JsonProperty("tools") List tools; + /** + * Controls which (if any) function is called by the model. none means the model will not call a + * function and instead generates a message. auto means the model can pick between generating a message or calling a + * function. Specifying a particular function via {"type: "function", "function": {"name": "my_function"}} forces + * the model to call that function. none is the default when no functions are present. auto is the default if + * functions are present. Use the {@link ToolChoiceBuilder} to create a tool choice object. + */ + private @JsonProperty("tool_choice") Object toolChoice; + /** + * A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. + */ + private @JsonProperty("user") String user; + /** + * Whether to enable parallel function calling during tool use. + * Defaults to true. + */ + private @JsonProperty("parallel_tool_calls") Boolean parallelToolCalls; + /** + * Whether to store the output of this chat completion request for use in our model distillation or evals products. + */ + private @JsonProperty("store") Boolean store; + + /** + * Developer-defined tags and values used for filtering completions in the dashboard. + */ + private @JsonProperty("metadata") Map metadata; + + /** + * Constrains effort on reasoning for reasoning models. Currently supported values are low, medium, and high. + * Reducing reasoning effort can result in faster responses and fewer tokens used on reasoning in a response. + * Optional. Defaults to medium. + * Only for 'o1' models. + */ + private @JsonProperty("reasoning_effort") String reasoningEffort; + + /** + * verbosity: string or null + * Optional - Defaults to medium + * Constrains the verbosity of the model's response. Lower values will result in more concise responses, while higher values will result in more verbose responses. + * Currently supported values are low, medium, and high. + * If specified, the model will use web search to find relevant information to answer the user's question. + */ + private @JsonProperty("verbosity") String verbosity; + + /** + * This tool searches the web for relevant results to use in a response. + */ + private @JsonProperty("web_search_options") WebSearchOptions webSearchOptions; + + /** + * Specifies the processing type used for serving the request. + */ + private @JsonProperty("service_tier") String serviceTier; + + /** + * Collection of {@link ToolCallback}s to be used for tool calling in the chat completion requests. + */ + @JsonIgnore + private List toolCallbacks = new ArrayList<>(); + + /** + * Collection of tool names to be resolved at runtime and used for tool calling in the chat completion requests. + */ + @JsonIgnore + private Set toolNames = new HashSet<>(); + + /** + * Whether to enable the tool execution lifecycle internally in ChatModel. + */ + @JsonIgnore + private Boolean internalToolExecutionEnabled; + + /** + * Optional HTTP headers to be added to the chat completion request. + */ + @JsonIgnore + private Map httpHeaders = new HashMap<>(); + + @JsonIgnore + private Map toolContext = new HashMap<>(); + + private @JsonProperty("extra_body") Map extraBody; + + // @formatter:on public static Builder builder() { return new Builder(); @@ -748,17 +735,19 @@ public Builder topLogprobs(Integer topLogprobs) { * *

* Model-specific usage: - * + *

*
    - *
  • Use for non-reasoning models (e.g., gpt-4o, gpt-3.5-turbo) + *
  • Use for non-reasoning models (e.g., gpt-4o, + * gpt-3.5-turbo)
  • *
  • Cannot be used with reasoning models (e.g., o1, o3, - * o4-mini series) + * o4-mini series)
  • *
* *

* Mutual exclusivity: This parameter cannot be used together * with {@link #maxCompletionTokens(Integer)}. If both are set, the last one set * will be used and the other will be cleared with a warning. + *

* @param maxTokens the maximum number of tokens to generate, or null to unset * @return this builder instance */ @@ -780,18 +769,19 @@ public Builder maxTokens(Integer maxTokens) { * *

* Model-specific usage: - * + *

*
    *
  • Required for reasoning models (e.g., o1, o3, o4-mini - * series) + * series)
  • *
  • Cannot be used with non-reasoning models (e.g., gpt-4o, - * gpt-3.5-turbo) + * gpt-3.5-turbo)
  • *
* *

* Mutual exclusivity: This parameter cannot be used together * with {@link #maxTokens(Integer)}. If both are set, the last one set will be * used and the other will be cleared with a warning. + *

* @param maxCompletionTokens the maximum number of completion tokens to generate, * or null to unset * @return this builder instance diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index 0dea0e6e291..9a654eb449b 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -16,6 +16,13 @@ package org.springframework.ai.openai.api; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Consumer; +import java.util.function.Predicate; +import java.util.stream.Collectors; + import com.fasterxml.jackson.annotation.JsonFormat; import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; @@ -26,12 +33,9 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.annotation.JsonDeserialize; import com.fasterxml.jackson.databind.node.ObjectNode; -import java.util.List; -import java.util.Map; -import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Consumer; -import java.util.function.Predicate; -import java.util.stream.Collectors; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + import org.springframework.ai.model.ApiKey; import org.springframework.ai.model.ChatModelDescription; import org.springframework.ai.model.ModelOptionsUtils; @@ -50,8 +54,6 @@ import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; /** * Single class implementation of the @@ -196,7 +198,7 @@ public ResponseEntity chatCompletionEntity(ChatCompletionRequest headers.addAll(additionalHttpHeader); addDefaultHeadersIfMissing(headers); }) - .body(dynamicRequestBody) + .body(dynamicRequestBody) .retrieve() .toEntity(ChatCompletion.class); // @formatter:on @@ -260,15 +262,13 @@ public Flux chatCompletionStream(ChatCompletionRequest chat throw new RuntimeException(e); } Object dynamicBody = createDynamicRequestBody(chatRequest); - // @formatter:off - return this.webClient - .post() - .uri(this.completionsPath) - .headers( - headers -> { - headers.addAll(additionalHttpHeader); - addDefaultHeadersIfMissing(headers); - }) // @formatter:on + // @formatter:off + return this.webClient.post() + .uri(this.completionsPath) + .headers(headers -> { + headers.addAll(additionalHttpHeader); + addDefaultHeadersIfMissing(headers); + }) // @formatter:on .bodyValue(dynamicBody) .retrieve() .bodyToFlux(String.class) @@ -1123,38 +1123,37 @@ public enum OutputModality { * @param verbosity Controls the verbosity of the model's response. */ @JsonInclude(Include.NON_NULL) - public record ChatCompletionRequest( // @formatter:off - @JsonProperty("messages") List messages, - @JsonProperty("model") String model, - @JsonProperty("store") Boolean store, - @JsonProperty("metadata") Map metadata, - @JsonProperty("frequency_penalty") Double frequencyPenalty, - @JsonProperty("logit_bias") Map logitBias, - @JsonProperty("logprobs") Boolean logprobs, - @JsonProperty("top_logprobs") Integer topLogprobs, - @JsonProperty("max_tokens") Integer maxTokens, // original field for specifying token usage. - @JsonProperty("max_completion_tokens") - Integer maxCompletionTokens, // new field for gpt-o1 and other reasoning models - @JsonProperty("n") Integer n, - @JsonProperty("modalities") List outputModalities, - @JsonProperty("audio") AudioParameters audioParameters, - @JsonProperty("presence_penalty") Double presencePenalty, - @JsonProperty("response_format") ResponseFormat responseFormat, - @JsonProperty("seed") Integer seed, - @JsonProperty("service_tier") String serviceTier, - @JsonProperty("stop") List stop, - @JsonProperty("stream") Boolean stream, - @JsonProperty("stream_options") StreamOptions streamOptions, - @JsonProperty("temperature") Double temperature, - @JsonProperty("top_p") Double topP, - @JsonProperty("tools") List tools, - @JsonProperty("tool_choice") Object toolChoice, - @JsonProperty("parallel_tool_calls") Boolean parallelToolCalls, - @JsonProperty("user") String user, - @JsonProperty("reasoning_effort") String reasoningEffort, - @JsonProperty("web_search_options") WebSearchOptions webSearchOptions, - @JsonProperty("verbosity") String verbosity, - @JsonProperty("extra_body") Map extraBody) { + public record ChatCompletionRequest(// @formatter:off + @JsonProperty("messages") List messages, + @JsonProperty("model") String model, + @JsonProperty("store") Boolean store, + @JsonProperty("metadata") Map metadata, + @JsonProperty("frequency_penalty") Double frequencyPenalty, + @JsonProperty("logit_bias") Map logitBias, + @JsonProperty("logprobs") Boolean logprobs, + @JsonProperty("top_logprobs") Integer topLogprobs, + @JsonProperty("max_tokens") Integer maxTokens, // original field for specifying token usage. + @JsonProperty("max_completion_tokens") Integer maxCompletionTokens, // new field for gpt-o1 and other reasoning models + @JsonProperty("n") Integer n, + @JsonProperty("modalities") List outputModalities, + @JsonProperty("audio") AudioParameters audioParameters, + @JsonProperty("presence_penalty") Double presencePenalty, + @JsonProperty("response_format") ResponseFormat responseFormat, + @JsonProperty("seed") Integer seed, + @JsonProperty("service_tier") String serviceTier, + @JsonProperty("stop") List stop, + @JsonProperty("stream") Boolean stream, + @JsonProperty("stream_options") StreamOptions streamOptions, + @JsonProperty("temperature") Double temperature, + @JsonProperty("top_p") Double topP, + @JsonProperty("tools") List tools, + @JsonProperty("tool_choice") Object toolChoice, + @JsonProperty("parallel_tool_calls") Boolean parallelToolCalls, + @JsonProperty("user") String user, + @JsonProperty("reasoning_effort") String reasoningEffort, + @JsonProperty("web_search_options") WebSearchOptions webSearchOptions, + @JsonProperty("verbosity") String verbosity, + @JsonProperty("extra_body") Map extraBody) { /** * Shortcut constructor for a chat completion request with the given messages, model and temperature. @@ -1164,37 +1163,9 @@ public record ChatCompletionRequest( // @formatter:off * @param temperature What sampling temperature to use, between 0 and 1. */ public ChatCompletionRequest(List messages, String model, Double temperature) { - this( - messages, - model, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - false, - null, - temperature, - null, - null, - null, - null, - null, - null, - null, - null, - null); + this(messages, model, null, null, null, null, null, null, null, null, null, null, null, null, null, + null, null, null, false, null, temperature, null, + null, null, null, null, null, null, null,null); } /** @@ -1205,37 +1176,10 @@ public ChatCompletionRequest(List messages, String model, * @param audio Parameters for audio output. Required when audio output is requested with outputModalities: ["audio"]. */ public ChatCompletionRequest(List messages, String model, AudioParameters audio, boolean stream) { - this( - messages, - model, - null, - null, - null, - null, - null, - null, - null, - null, - null, - List.of(OutputModality.AUDIO, OutputModality.TEXT), - audio, - null, - null, - null, - null, - null, - stream, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null); + this(messages, model, null, null, null, null, null, null, + null, null, null, List.of(OutputModality.AUDIO, OutputModality.TEXT), audio, null, null, + null, null, null, stream, null, null, null, + null, null, null, null, null, null, null,null); } /** @@ -1248,37 +1192,9 @@ public ChatCompletionRequest(List messages, String model, * as they become available, with the stream terminated by a data: [DONE] message. */ public ChatCompletionRequest(List messages, String model, Double temperature, boolean stream) { - this( - messages, - model, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - stream, - null, - temperature, - null, - null, - null, - null, - null, - null, - null, - null, - null); + this(messages, model, null, null, null, null, null, null, null, null, null, + null, null, null, null, null, null, null, stream, null, temperature, null, + null, null, null, null, null, null, null,null); } /** @@ -1292,37 +1208,9 @@ public ChatCompletionRequest(List messages, String model, */ public ChatCompletionRequest(List messages, String model, List tools, Object toolChoice) { - this( - messages, - model, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - null, - false, - null, - 0.8, - null, - tools, - toolChoice, - null, - null, - null, - null, - null, - null); + this(messages, model, null, null, null, null, null, null, null, null, null, + null, null, null, null, null, null, null, false, null, 0.8, null, + tools, toolChoice, null, null, null, null, null,null); } /** @@ -1333,10 +1221,9 @@ public ChatCompletionRequest(List messages, String model, * as they become available, with the stream terminated by a data: [DONE] message. */ public ChatCompletionRequest(List messages, Boolean stream) { - this( - messages, null, null, null, null, null, null, null, null, null, null, null, null, null, - null, null, null, null, stream, null, null, null, null, null, null, null, null, null, - null, null); + this(messages, null, null, null, null, null, null, null, null, null, null, null, null, null, + null, null, null, null, stream, null, null, null, null, null, null, null, null, null, + null, null); } /** @@ -1346,37 +1233,10 @@ public ChatCompletionRequest(List messages, Boolean strea * @return A new {@link ChatCompletionRequest} with the specified stream options. */ public ChatCompletionRequest streamOptions(StreamOptions streamOptions) { - return new ChatCompletionRequest( - this.messages, - this.model, - this.store, - this.metadata, - this.frequencyPenalty, - this.logitBias, - this.logprobs, - this.topLogprobs, - this.maxTokens, - this.maxCompletionTokens, - this.n, - this.outputModalities, - this.audioParameters, - this.presencePenalty, - this.responseFormat, - this.seed, - this.serviceTier, - this.stop, - this.stream, - streamOptions, - this.temperature, - this.topP, - this.tools, - this.toolChoice, - this.parallelToolCalls, - this.user, - this.reasoningEffort, - this.webSearchOptions, - this.verbosity, - this.extraBody); + return new ChatCompletionRequest(this.messages, this.model, this.store, this.metadata, this.frequencyPenalty, this.logitBias, this.logprobs, + this.topLogprobs, this.maxTokens, this.maxCompletionTokens, this.n, this.outputModalities, this.audioParameters, this.presencePenalty, + this.responseFormat, this.seed, this.serviceTier, this.stop, this.stream, streamOptions, this.temperature, this.topP, + this.tools, this.toolChoice, this.parallelToolCalls, this.user, this.reasoningEffort, this.webSearchOptions, this.verbosity,this.extraBody); } /** @@ -1580,18 +1440,17 @@ public String getValue() { */ @JsonInclude(Include.NON_NULL) @JsonIgnoreProperties(ignoreUnknown = true) - public record ChatCompletionMessage( // @formatter:off - @JsonProperty("content") Object rawContent, - @JsonProperty("role") Role role, - @JsonProperty("name") String name, - @JsonProperty("tool_call_id") String toolCallId, - @JsonProperty("tool_calls") - @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) - List toolCalls, - @JsonProperty("refusal") String refusal, - @JsonProperty("audio") AudioOutput audioOutput, - @JsonProperty("annotations") List annotations, - @JsonProperty("reasoning_content") String reasoningContent) { // @formatter:on + public record ChatCompletionMessage(// @formatter:off + @JsonProperty("content") Object rawContent, + @JsonProperty("role") Role role, + @JsonProperty("name") String name, + @JsonProperty("tool_call_id") String toolCallId, + @JsonProperty("tool_calls") @JsonFormat(with = JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY) List toolCalls, + @JsonProperty("refusal") String refusal, + @JsonProperty("audio") AudioOutput audioOutput, + @JsonProperty("annotations") List annotations, + @JsonProperty("reasoning_content") String reasoningContent + ) { // @formatter:on /** * Create a chat completion message with the given content and role. All other diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java index 0628acb4ab6..463d68e4023 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelper.java @@ -18,6 +18,7 @@ import java.util.ArrayList; import java.util.List; + import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion.Choice; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk; diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java index 8cca66efe67..01b7635418c 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiBuilderTests.java @@ -16,15 +16,12 @@ package org.springframework.ai.openai.api; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; -import static org.mockito.Mockito.mock; - import java.io.IOException; import java.util.LinkedList; import java.util.List; import java.util.Objects; import java.util.Queue; + import okhttp3.mockwebserver.MockResponse; import okhttp3.mockwebserver.MockWebServer; import okhttp3.mockwebserver.RecordedRequest; @@ -33,6 +30,7 @@ import org.junit.jupiter.api.Nested; import org.junit.jupiter.api.Test; import org.opentest4j.AssertionFailedError; + import org.springframework.ai.model.ApiKey; import org.springframework.ai.model.SimpleApiKey; import org.springframework.http.HttpHeaders; @@ -45,6 +43,10 @@ import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + public class OpenAiApiBuilderTests { private static final ApiKey TEST_API_KEY = new SimpleApiKey("test-api-key"); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java index 12c3bba3924..55c223d818a 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java @@ -16,17 +16,17 @@ package org.springframework.ai.openai.api; -import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatThrownBy; - import java.io.IOException; import java.util.Base64; import java.util.List; + import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.EnumSource; +import reactor.core.publisher.Flux; + import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage; @@ -36,7 +36,9 @@ import org.springframework.ai.openai.api.OpenAiApi.EmbeddingList; import org.springframework.core.io.ClassPathResource; import org.springframework.http.ResponseEntity; -import reactor.core.publisher.Flux; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; /** * @author Christian Tzolov diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelperTest.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelperTest.java index dd3df247faf..6e9967908e5 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelperTest.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiStreamFunctionCallingHelperTest.java @@ -16,14 +16,15 @@ package org.springframework.ai.openai.api; -import static org.assertj.core.api.Assertions.assertThat; - import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.function.Consumer; + import org.junit.jupiter.api.Test; import org.mockito.Mockito; + +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; /** @@ -255,7 +256,7 @@ public void merge_partialFieldsFromEachChunk() { public void isStreamingToolFunctionCall_withMultipleChoicesAndOnlyFirstHasToolCalls() { var toolCall = Mockito.mock(OpenAiApi.ChatCompletionMessage.ToolCall.class); var deltaWithToolCalls = new OpenAiApi.ChatCompletionMessage(null, null, null, null, List.of(toolCall), null, - null, null); + null, null, null); var deltaWithoutToolCalls = new OpenAiApi.ChatCompletionMessage(null, null); var choice1 = new OpenAiApi.ChatCompletionChunk.ChunkChoice(null, null, deltaWithToolCalls, null); @@ -327,7 +328,7 @@ public void edgeCases_emptyStringFields() { @Test public void isStreamingToolFunctionCall_withNullToolCallsList() { - var delta = new OpenAiApi.ChatCompletionMessage(null, null, null, null, null, null, null, null); + var delta = new OpenAiApi.ChatCompletionMessage(null, null, null, null, null, null, null, null, null); var choice = new OpenAiApi.ChatCompletionChunk.ChunkChoice(null, null, delta, null); var chunk = new OpenAiApi.ChatCompletionChunk(null, List.of(choice), null, null, null, null, null, null); diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java index e4ddf7b42aa..3ce6e3f24e6 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/tool/OpenAiApiToolFunctionCallIT.java @@ -16,16 +16,16 @@ package org.springframework.ai.openai.api.tool; -import static org.assertj.core.api.Assertions.assertThat; +import java.util.ArrayList; +import java.util.List; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; -import java.util.ArrayList; -import java.util.List; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; + import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.openai.api.OpenAiApi; import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion; @@ -36,6 +36,8 @@ import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest.ToolChoiceBuilder; import org.springframework.http.ResponseEntity; +import static org.assertj.core.api.Assertions.assertThat; + /** * Based on the OpenAI Function Calling tutorial: * https://platform.openai.com/docs/guides/function-calling/parallel-function-calling diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java index 766f6a101b9..e19e82640b2 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiRetryTests.java @@ -16,14 +16,9 @@ package org.springframework.ai.openai.chat; -import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.BDDMockito.given; - import java.util.List; import java.util.Optional; + import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; @@ -32,6 +27,8 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import reactor.core.publisher.Flux; + import org.springframework.ai.audio.transcription.AudioTranscriptionPrompt; import org.springframework.ai.audio.transcription.AudioTranscriptionResponse; import org.springframework.ai.chat.prompt.Prompt; @@ -73,7 +70,12 @@ import org.springframework.retry.RetryContext; import org.springframework.retry.RetryListener; import org.springframework.retry.support.RetryTemplate; -import reactor.core.publisher.Flux; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.BDDMockito.given; /** * @author Christian Tzolov diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiStreamingFinishReasonTests.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiStreamingFinishReasonTests.java index a35fce716ca..1664782314b 100644 --- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiStreamingFinishReasonTests.java +++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiStreamingFinishReasonTests.java @@ -16,18 +16,16 @@ package org.springframework.ai.openai.chat; -import static org.assertj.core.api.Assertions.assertThat; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.isA; -import static org.mockito.BDDMockito.given; +import java.util.List; import com.fasterxml.jackson.core.JsonProcessingException; import io.micrometer.observation.ObservationRegistry; -import java.util.List; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import reactor.core.publisher.Flux; + import org.springframework.ai.chat.model.ChatResponse; import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.model.ModelOptionsUtils; @@ -43,7 +41,11 @@ import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest; import org.springframework.ai.retry.RetryUtils; import org.springframework.retry.support.RetryTemplate; -import reactor.core.publisher.Flux; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.isA; +import static org.mockito.BDDMockito.given; /** * Tests for OpenAI streaming responses with various finish_reason scenarios, particularly From bee448032b29c7c452cc429ea1fd6c391dc1ac2a Mon Sep 17 00:00:00 2001 From: SenreySong <25841017+SenreySong@users.noreply.github.com> Date: Thu, 25 Sep 2025 20:30:13 +0800 Subject: [PATCH 5/5] refactor(spring-ai-openai): fix style check Signed-off-by: SenreySong <25841017+SenreySong@users.noreply.github.com> --- .../springframework/ai/openai/OpenAiChatOptions.java | 2 +- .../org/springframework/ai/openai/api/OpenAiApi.java | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java index bdbb308d6ef..c7c2159d988 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatOptions.java @@ -506,7 +506,7 @@ public void setParallelToolCalls(Boolean parallelToolCalls) { } public Map getExtraBody() { - return extraBody; + return this.extraBody; } public void setExtraBody(Map extraBody) { diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index 9a654eb449b..2a58a79c2ef 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -1165,7 +1165,7 @@ public record ChatCompletionRequest(// @formatter:off public ChatCompletionRequest(List messages, String model, Double temperature) { this(messages, model, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, false, null, temperature, null, - null, null, null, null, null, null, null,null); + null, null, null, null, null, null, null, null); } /** @@ -1179,7 +1179,7 @@ public ChatCompletionRequest(List messages, String model, this(messages, model, null, null, null, null, null, null, null, null, null, List.of(OutputModality.AUDIO, OutputModality.TEXT), audio, null, null, null, null, null, stream, null, null, null, - null, null, null, null, null, null, null,null); + null, null, null, null, null, null, null, null); } /** @@ -1194,7 +1194,7 @@ public ChatCompletionRequest(List messages, String model, public ChatCompletionRequest(List messages, String model, Double temperature, boolean stream) { this(messages, model, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, stream, null, temperature, null, - null, null, null, null, null, null, null,null); + null, null, null, null, null, null, null, null); } /** @@ -1210,7 +1210,7 @@ public ChatCompletionRequest(List messages, String model, List tools, Object toolChoice) { this(messages, model, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, null, false, null, 0.8, null, - tools, toolChoice, null, null, null, null, null,null); + tools, toolChoice, null, null, null, null, null, null); } /** @@ -1236,7 +1236,7 @@ public ChatCompletionRequest streamOptions(StreamOptions streamOptions) { return new ChatCompletionRequest(this.messages, this.model, this.store, this.metadata, this.frequencyPenalty, this.logitBias, this.logprobs, this.topLogprobs, this.maxTokens, this.maxCompletionTokens, this.n, this.outputModalities, this.audioParameters, this.presencePenalty, this.responseFormat, this.seed, this.serviceTier, this.stop, this.stream, streamOptions, this.temperature, this.topP, - this.tools, this.toolChoice, this.parallelToolCalls, this.user, this.reasoningEffort, this.webSearchOptions, this.verbosity,this.extraBody); + this.tools, this.toolChoice, this.parallelToolCalls, this.user, this.reasoningEffort, this.webSearchOptions, this.verbosity, this.extraBody); } /**