Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,29 +16,30 @@

package org.springframework.ai.openai.api;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;

import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpHeaders;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.http.converter.json.Jackson2ObjectMapperBuilder;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.web.client.ResponseErrorHandler;
import org.springframework.web.client.RestClient;
import org.springframework.web.reactive.function.client.WebClient;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.io.IOException;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.function.Predicate;

// @formatter:off
/**
Expand All @@ -51,7 +52,7 @@ public class OpenAiApi {

private static final String DEFAULT_BASE_URL = "https://api.openai.com";
private static final String DEFAULT_EMBEDDING_MODEL = "text-embedding-ada-002";
private static final String SSE_DONE = "[DONE]";
private static final Predicate<String> SSE_DONE_PREDICATE = "[DONE]"::equals;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I presume that the Predicate is just a code style improvement, not due to issues with the existing conditions?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The documentation at https://platform.openai.com/docs/api-reference/chat/create explicitly states that it ends with '[DONE]', so in my opinion, using 'equals' would be clearer than 'contains.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like the predicate.


private final RestClient restClient;
private final WebClient webClient;
Expand All @@ -74,8 +75,8 @@ public OpenAiApi(String openAiToken) {
* @param restClientBuilder RestClient builder.
*/
public OpenAiApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder) {

this.objectMapper = new ObjectMapper();
// Use the same ObjectMapper for WebClient's response as the one used in RestClient.
this.objectMapper = Jackson2ObjectMapperBuilder.json().build();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is interesting, need to dig a bit more.
The recent streaming response parsing failure was due to the unexpected (for me) appearance of the LogProbs in the chunk response. Not sure if i missed it or OpenAI introduced it silently. Hope it is the former.
I deliberately didn't configure the objectmapper to ignore unknown fields so i can catch changes. But didn't expect they will come without announcement and also didn't consider that the webclient and restclient converters are likely ignoring such fields.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I encountered an error during code review. It appears that LogProbs were added at some point. After checking OpenAI's API specification at https://platform.openai.com/docs/api-reference/chat/object, I noticed the presence of LogProbs. However, in the response for the ChatCompletion - Choice record when the stream is false, LogProbs are not included. To address this issue, I realized that a similar approach is needed for WebClient, similar to the ObjectMapper in RestClient, to ignore unknown fields. Therefore, I added the same configuration in the commit.

Instead of adding LogProbs to ChatCompletion - Choice and ChatCompletionChunk record, considering the continuous changes and potential additions or deprecations in OpenAI's spec, I believe it's more beneficial for developers using Spring AI to configure the ObjectMapper to ignore unknown fields. This way, they won't have to deal with too many specific options. Hence, I implemented the ObjectMapper configuration in the current commit.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already added the LogProbs in another commit.
But with this PR I'll configure the object mapper to tolerate unknown properties.
Would live out the Jackson2ObjectMapperBuilder for now.


Consumer<HttpHeaders> jsonContentHeaders = headers -> {
headers.setBearerAuth(openAiToken);
Expand Down Expand Up @@ -533,7 +534,15 @@ public record ChatCompletionChunk(
public record ChunkChoice(
@JsonProperty("finish_reason") ChatCompletionFinishReason finishReason,
@JsonProperty("index") Integer index,
@JsonProperty("delta") ChatCompletionMessage delta) {
@JsonProperty("delta") Delta delta
) {
// The delta used in ChatCompletionChunk is not identical to ChatCompletion's ChatCompletionMessage, so a Delta record needs to be defined separately.
public record Delta(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm afraid that this is incorrect. The message format is the same in both the chat completion object and the chat completion chunk object and it includes the tool_calls hierarchy as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I misunderstood. Thank you for letting me know.


@JsonProperty("role") ChatCompletionMessage.Role role,

@JsonProperty("content") String content) {
}
}
}

Expand Down Expand Up @@ -571,10 +580,10 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat
.body(Mono.just(chatRequest), ChatCompletionRequest.class)
.retrieve()
.bodyToFlux(String.class)
// cancels the flux stream after the SSE_DONE is received.
.takeUntil(content -> content.contains(SSE_DONE))
// filters out the SSE_DONE message.
.filter(content -> !content.contains(SSE_DONE))
// cancels the flux stream after the "[DONE]" is received.
.takeUntil(SSE_DONE_PREDICATE)
// filters out the "[DONE]" message.
.filter(SSE_DONE_PREDICATE.negate())
.map(content -> parseJson(content, ChatCompletionChunk.class));
}

Expand Down