-
Notifications
You must be signed in to change notification settings - Fork 2k
Openai api sse review #172
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,29 +16,30 @@ | |
|
|
||
| package org.springframework.ai.openai.api; | ||
|
|
||
| import java.io.IOException; | ||
| import java.util.List; | ||
| import java.util.Map; | ||
| import java.util.function.Consumer; | ||
|
|
||
| import com.fasterxml.jackson.annotation.JsonInclude; | ||
| import com.fasterxml.jackson.annotation.JsonInclude.Include; | ||
| import com.fasterxml.jackson.annotation.JsonProperty; | ||
| import com.fasterxml.jackson.core.type.TypeReference; | ||
| import com.fasterxml.jackson.databind.ObjectMapper; | ||
| import reactor.core.publisher.Flux; | ||
| import reactor.core.publisher.Mono; | ||
|
|
||
| import org.springframework.core.ParameterizedTypeReference; | ||
| import org.springframework.http.HttpHeaders; | ||
| import org.springframework.http.MediaType; | ||
| import org.springframework.http.ResponseEntity; | ||
| import org.springframework.http.client.ClientHttpResponse; | ||
| import org.springframework.http.converter.json.Jackson2ObjectMapperBuilder; | ||
| import org.springframework.util.Assert; | ||
| import org.springframework.util.CollectionUtils; | ||
| import org.springframework.web.client.ResponseErrorHandler; | ||
| import org.springframework.web.client.RestClient; | ||
| import org.springframework.web.reactive.function.client.WebClient; | ||
| import reactor.core.publisher.Flux; | ||
| import reactor.core.publisher.Mono; | ||
|
|
||
| import java.io.IOException; | ||
| import java.util.List; | ||
| import java.util.Map; | ||
| import java.util.function.Consumer; | ||
| import java.util.function.Predicate; | ||
|
|
||
| // @formatter:off | ||
| /** | ||
|
|
@@ -51,7 +52,7 @@ public class OpenAiApi { | |
|
|
||
| private static final String DEFAULT_BASE_URL = "https://api.openai.com"; | ||
| private static final String DEFAULT_EMBEDDING_MODEL = "text-embedding-ada-002"; | ||
| private static final String SSE_DONE = "[DONE]"; | ||
| private static final Predicate<String> SSE_DONE_PREDICATE = "[DONE]"::equals; | ||
|
|
||
| private final RestClient restClient; | ||
| private final WebClient webClient; | ||
|
|
@@ -74,8 +75,8 @@ public OpenAiApi(String openAiToken) { | |
| * @param restClientBuilder RestClient builder. | ||
| */ | ||
| public OpenAiApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder) { | ||
|
|
||
| this.objectMapper = new ObjectMapper(); | ||
| // Use the same ObjectMapper for WebClient's response as the one used in RestClient. | ||
| this.objectMapper = Jackson2ObjectMapperBuilder.json().build(); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is interesting, need to dig a bit more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I encountered an error during code review. It appears that LogProbs were added at some point. After checking OpenAI's API specification at https://platform.openai.com/docs/api-reference/chat/object, I noticed the presence of LogProbs. However, in the response for the ChatCompletion - Choice record when the stream is false, LogProbs are not included. To address this issue, I realized that a similar approach is needed for WebClient, similar to the ObjectMapper in RestClient, to ignore unknown fields. Therefore, I added the same configuration in the commit. Instead of adding LogProbs to ChatCompletion - Choice and ChatCompletionChunk record, considering the continuous changes and potential additions or deprecations in OpenAI's spec, I believe it's more beneficial for developers using Spring AI to configure the ObjectMapper to ignore unknown fields. This way, they won't have to deal with too many specific options. Hence, I implemented the ObjectMapper configuration in the current commit. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We already added the LogProbs in another commit. |
||
|
|
||
| Consumer<HttpHeaders> jsonContentHeaders = headers -> { | ||
| headers.setBearerAuth(openAiToken); | ||
|
|
@@ -533,7 +534,15 @@ public record ChatCompletionChunk( | |
| public record ChunkChoice( | ||
| @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason, | ||
| @JsonProperty("index") Integer index, | ||
| @JsonProperty("delta") ChatCompletionMessage delta) { | ||
| @JsonProperty("delta") Delta delta | ||
| ) { | ||
| // The delta used in ChatCompletionChunk is not identical to ChatCompletion's ChatCompletionMessage, so a Delta record needs to be defined separately. | ||
| public record Delta( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm afraid that this is incorrect. The message format is the same in both the chat completion object and the chat completion chunk object and it includes the tool_calls hierarchy as well. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I misunderstood. Thank you for letting me know. |
||
|
|
||
| @JsonProperty("role") ChatCompletionMessage.Role role, | ||
|
|
||
| @JsonProperty("content") String content) { | ||
| } | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -571,10 +580,10 @@ public Flux<ChatCompletionChunk> chatCompletionStream(ChatCompletionRequest chat | |
| .body(Mono.just(chatRequest), ChatCompletionRequest.class) | ||
| .retrieve() | ||
| .bodyToFlux(String.class) | ||
| // cancels the flux stream after the SSE_DONE is received. | ||
| .takeUntil(content -> content.contains(SSE_DONE)) | ||
| // filters out the SSE_DONE message. | ||
| .filter(content -> !content.contains(SSE_DONE)) | ||
| // cancels the flux stream after the "[DONE]" is received. | ||
| .takeUntil(SSE_DONE_PREDICATE) | ||
| // filters out the "[DONE]" message. | ||
| .filter(SSE_DONE_PREDICATE.negate()) | ||
| .map(content -> parseJson(content, ChatCompletionChunk.class)); | ||
| } | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I presume that the Predicate is just a code style improvement, not due to issues with the existing conditions?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The documentation at https://platform.openai.com/docs/api-reference/chat/create explicitly states that it ends with '[DONE]', so in my opinion, using 'equals' would be clearer than 'contains.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like the predicate.