diff --git a/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java index de3460d31b7..b43069e9a79 100644 --- a/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java +++ b/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java @@ -16,29 +16,30 @@ package org.springframework.ai.openai.api; -import java.io.IOException; -import java.util.List; -import java.util.Map; -import java.util.function.Consumer; - import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.HttpHeaders; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.http.client.ClientHttpResponse; +import org.springframework.http.converter.json.Jackson2ObjectMapperBuilder; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.web.client.ResponseErrorHandler; import org.springframework.web.client.RestClient; import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.function.Consumer; +import java.util.function.Predicate; // @formatter:off /** @@ -51,7 +52,7 @@ public class OpenAiApi { private static final String DEFAULT_BASE_URL = "https://api.openai.com"; private static final String DEFAULT_EMBEDDING_MODEL = "text-embedding-ada-002"; - private static final String SSE_DONE = "[DONE]"; + private static final Predicate SSE_DONE_PREDICATE = "[DONE]"::equals; private final RestClient restClient; private final WebClient webClient; @@ -74,8 +75,8 @@ public OpenAiApi(String openAiToken) { * @param restClientBuilder RestClient builder. */ public OpenAiApi(String baseUrl, String openAiToken, RestClient.Builder restClientBuilder) { - - this.objectMapper = new ObjectMapper(); + // Use the same ObjectMapper for WebClient's response as the one used in RestClient. + this.objectMapper = Jackson2ObjectMapperBuilder.json().build(); Consumer jsonContentHeaders = headers -> { headers.setBearerAuth(openAiToken); @@ -533,7 +534,15 @@ public record ChatCompletionChunk( public record ChunkChoice( @JsonProperty("finish_reason") ChatCompletionFinishReason finishReason, @JsonProperty("index") Integer index, - @JsonProperty("delta") ChatCompletionMessage delta) { + @JsonProperty("delta") Delta delta + ) { + // The delta used in ChatCompletionChunk is not identical to ChatCompletion's ChatCompletionMessage, so a Delta record needs to be defined separately. + public record Delta( + + @JsonProperty("role") ChatCompletionMessage.Role role, + + @JsonProperty("content") String content) { + } } } @@ -571,10 +580,10 @@ public Flux chatCompletionStream(ChatCompletionRequest chat .body(Mono.just(chatRequest), ChatCompletionRequest.class) .retrieve() .bodyToFlux(String.class) - // cancels the flux stream after the SSE_DONE is received. - .takeUntil(content -> content.contains(SSE_DONE)) - // filters out the SSE_DONE message. - .filter(content -> !content.contains(SSE_DONE)) + // cancels the flux stream after the "[DONE]" is received. + .takeUntil(SSE_DONE_PREDICATE) + // filters out the "[DONE]" message. + .filter(SSE_DONE_PREDICATE.negate()) .map(content -> parseJson(content, ChatCompletionChunk.class)); }