From 620aff66cd5076ade8a097aaaccf27b017390ebd Mon Sep 17 00:00:00 2001 From: Jemin Huh Date: Wed, 22 Nov 2023 00:45:17 +0900 Subject: [PATCH 1/3] Add OpenAI Stream Client implementation --- .../ai/openai/client/OpenAiStreamClient.java | 52 +++++++++++++++++++ .../ai/openai/OpenAiTestConfiguration.java | 9 ++++ .../ai/openai/client/ClientIT.java | 28 +++++++++- .../ai/openai/testutils/AbstractIT.java | 4 ++ 4 files changed, 92 insertions(+), 1 deletion(-) create mode 100644 spring-ai-openai/src/main/java/org/springframework/ai/openai/client/OpenAiStreamClient.java diff --git a/spring-ai-openai/src/main/java/org/springframework/ai/openai/client/OpenAiStreamClient.java b/spring-ai-openai/src/main/java/org/springframework/ai/openai/client/OpenAiStreamClient.java new file mode 100644 index 00000000000..f6073800912 --- /dev/null +++ b/spring-ai-openai/src/main/java/org/springframework/ai/openai/client/OpenAiStreamClient.java @@ -0,0 +1,52 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.client; + +import com.theokanning.openai.completion.chat.ChatCompletionChunk; +import com.theokanning.openai.completion.chat.ChatCompletionRequest; +import com.theokanning.openai.completion.chat.ChatMessage; +import com.theokanning.openai.service.OpenAiService; +import io.reactivex.Flowable; +import org.springframework.ai.prompt.Prompt; +import org.springframework.ai.prompt.messages.Message; + +import java.util.List; + +public class OpenAiStreamClient extends OpenAiClient { + private final OpenAiService openAiService; + + public OpenAiStreamClient(OpenAiService openAiService) { + super(openAiService); + this.openAiService = openAiService; + } + + public Flowable generateStream(Prompt prompt) { + + List messages = prompt.getMessages(); + + List theoMessages = + messages.stream().map(message -> new ChatMessage(message.getMessageTypeValue(), message.getContent())) + .toList(); + + ChatCompletionRequest chatCompletionRequest = + ChatCompletionRequest.builder().model(getModel()).temperature(getTemperature()).messages(theoMessages) + .stream(true).build(); + + return this.openAiService.streamChatCompletion(chatCompletionRequest); + } + +} diff --git a/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java b/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java index 878cc74a6f6..6b15ed07c87 100644 --- a/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java +++ b/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java @@ -3,6 +3,7 @@ import com.theokanning.openai.service.OpenAiService; import org.springframework.ai.embedding.EmbeddingClient; import org.springframework.ai.openai.client.OpenAiClient; +import org.springframework.ai.openai.client.OpenAiStreamClient; import org.springframework.ai.openai.embedding.OpenAiEmbeddingClient; import org.springframework.boot.SpringBootConfiguration; import org.springframework.context.annotation.Bean; @@ -36,4 +37,12 @@ public EmbeddingClient openAiEmbeddingClient(OpenAiService theoOpenAiService) { return new OpenAiEmbeddingClient(theoOpenAiService); } + + @Bean + public OpenAiStreamClient openAiStreamClient(OpenAiService theoOpenAiService) { + OpenAiStreamClient OpenAiStreamClient = new OpenAiStreamClient(theoOpenAiService); + OpenAiStreamClient.setTemperature(0.3); + return OpenAiStreamClient; + } + } diff --git a/spring-ai-openai/src/test/java/org/springframework/ai/openai/client/ClientIT.java b/spring-ai-openai/src/test/java/org/springframework/ai/openai/client/ClientIT.java index cd7dc7d535c..9328e16ddbc 100644 --- a/spring-ai-openai/src/test/java/org/springframework/ai/openai/client/ClientIT.java +++ b/spring-ai-openai/src/test/java/org/springframework/ai/openai/client/ClientIT.java @@ -1,6 +1,8 @@ package org.springframework.ai.openai.client; -import org.junit.jupiter.api.Disabled; +import com.theokanning.openai.completion.chat.ChatCompletionChoice; +import com.theokanning.openai.completion.chat.ChatCompletionChunk; +import com.theokanning.openai.completion.chat.ChatMessage; import org.junit.jupiter.api.Test; import org.springframework.ai.client.AiResponse; import org.springframework.ai.client.Generation; @@ -21,6 +23,7 @@ import java.util.Arrays; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; import static org.assertj.core.api.Assertions.assertThat; @@ -124,4 +127,27 @@ void beanOutputParserRecords() { assertThat(actorsFilms.movies()).hasSize(5); } + @Test + void beanStreamOutputParserRecords() { + + BeanOutputParser outputParser = new BeanOutputParser<>(ActorsFilmsRecord.class); + + String format = outputParser.getFormat(); + String template = """ + Generate the filmography of 5 movies for Tom Hanks. + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + String generationTextFromStream = openAiStreamClient.generateStream(prompt).toList().blockingGet().stream() + .map(ChatCompletionChunk::getChoices).flatMap(List::stream) + .map(ChatCompletionChoice::getMessage).map(ChatMessage::getContent) + .collect(Collectors.joining()); + + ActorsFilmsRecord actorsFilms = outputParser.parse(generationTextFromStream); + System.out.println(actorsFilms); + assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); + assertThat(actorsFilms.movies()).hasSize(5); + } + } diff --git a/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java b/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java index 8cec9f7e0b6..87d00bf30f1 100644 --- a/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java +++ b/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java @@ -4,6 +4,7 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.client.AiClient; import org.springframework.ai.client.AiResponse; +import org.springframework.ai.openai.client.OpenAiStreamClient; import org.springframework.ai.prompt.Prompt; import org.springframework.ai.prompt.PromptTemplate; import org.springframework.ai.prompt.messages.Message; @@ -25,6 +26,9 @@ public abstract class AbstractIT { @Autowired protected AiClient openAiClient; + @Autowired + protected OpenAiStreamClient openAiStreamClient; + @Value("classpath:/prompts/eval/qa-evaluator-accurate-answer.st") protected Resource qaEvaluatorAccurateAnswerResource; From 3cabc7ffddb0912e50be3a243b41aa3f313ce804 Mon Sep 17 00:00:00 2001 From: Jemin Huh Date: Sun, 26 Nov 2023 16:25:43 +0900 Subject: [PATCH 2/3] Add the latest Request and Response Classes of OpenAI's Chat Completion API --- .../openai/client/ChatCompletionResponse.java | 43 +++ .../openai/client/ChatCompletionsRequest.java | 279 ++++++++++++++++++ .../ai/openai/client/OpenAiChatMessage.java | 89 ++++++ .../ai/openai/client/OpenAiSseResponse.java | 48 +++ 4 files changed, 459 insertions(+) create mode 100644 spring-ai-openai/src/main/java/org/springframework/ai/openai/client/ChatCompletionResponse.java create mode 100644 spring-ai-openai/src/main/java/org/springframework/ai/openai/client/ChatCompletionsRequest.java create mode 100644 spring-ai-openai/src/main/java/org/springframework/ai/openai/client/OpenAiChatMessage.java create mode 100644 spring-ai-openai/src/main/java/org/springframework/ai/openai/client/OpenAiSseResponse.java diff --git a/spring-ai-openai/src/main/java/org/springframework/ai/openai/client/ChatCompletionResponse.java b/spring-ai-openai/src/main/java/org/springframework/ai/openai/client/ChatCompletionResponse.java new file mode 100644 index 00000000000..bedde945942 --- /dev/null +++ b/spring-ai-openai/src/main/java/org/springframework/ai/openai/client/ChatCompletionResponse.java @@ -0,0 +1,43 @@ +package org.springframework.ai.openai.client; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; +import java.util.Map; + +public record ChatCompletionResponse( + + @JsonProperty("id") + String id, + + @JsonProperty("choices") + List choices, + + @JsonProperty("created") + Integer created, + + @JsonProperty("model") + String model, + + @JsonProperty("system_fingerprint") + String systemFingerprint, + + @JsonProperty("object") + String object, + + @JsonProperty("usage") + Map usage +) { + public record Choice( + + @JsonProperty("finish_reason") + String finishReason, + + @JsonProperty("index") + Integer index, + + @JsonProperty("message") + OpenAiChatMessage message + ) { + } +} \ No newline at end of file diff --git a/spring-ai-openai/src/main/java/org/springframework/ai/openai/client/ChatCompletionsRequest.java b/spring-ai-openai/src/main/java/org/springframework/ai/openai/client/ChatCompletionsRequest.java new file mode 100644 index 00000000000..835fbd0aea0 --- /dev/null +++ b/spring-ai-openai/src/main/java/org/springframework/ai/openai/client/ChatCompletionsRequest.java @@ -0,0 +1,279 @@ +package org.springframework.ai.openai.client; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; + +import java.util.List; +import java.util.Map; + +@JsonDeserialize(builder = ChatCompletionsRequest.Builder.class) +public class ChatCompletionsRequest { + + private final List messages; + private final String model; + private final Integer frequencyPenalty; + private final Map logitBias; + private final Integer maxTokens; + private final Integer n; + private final Integer presencePenalty; + private final ResponseFormat responseFormat; + private final Integer seed; + private final List stop; + private final Boolean stream; + private final Double temperature; + private final Integer topP; + private final List tools; + private final String toolChoice; + private final String user; + + private ChatCompletionsRequest(Builder builder) { + this.messages = builder.messages; + this.model = builder.model; + this.frequencyPenalty = builder.frequencyPenalty; + this.logitBias = builder.logitBias; + this.maxTokens = builder.maxTokens; + this.n = builder.n; + this.presencePenalty = builder.presencePenalty; + this.responseFormat = builder.responseFormat; + this.seed = builder.seed; + this.stop = builder.stop; + this.stream = builder.stream; + this.temperature = builder.temperature; + this.topP = builder.topP; + this.tools = builder.tools; + this.toolChoice = builder.toolChoice; + this.user = builder.user; + } + + public List getMessages() { + return messages; + } + + public String getModel() { + return model; + } + + public Integer getFrequencyPenalty() { + return frequencyPenalty; + } + + public Map getLogitBias() { + return logitBias; + } + + public Integer getMaxTokens() { + return maxTokens; + } + + public Integer getN() { + return n; + } + + public Integer getPresencePenalty() { + return presencePenalty; + } + + public ResponseFormat getResponseFormat() { + return responseFormat; + } + + public Integer getSeed() { + return seed; + } + + public List getStop() { + return stop; + } + + public Boolean getStream() { + return stream; + } + + public Double getTemperature() { + return temperature; + } + + public Integer getTopP() { + return topP; + } + + public List getTools() { + return tools; + } + + public String getToolChoice() { + return toolChoice; + } + + public String getUser() { + return user; + } + + public static class Builder { + + @JsonProperty("messages") + private List messages; + + @JsonProperty("model") + private String model; + + @JsonProperty("frequency_penalty") + private Integer frequencyPenalty; + + @JsonProperty("logit_bias") + private Map logitBias; + + @JsonProperty("max_tokens") + private Integer maxTokens; + + @JsonProperty("n") + private Integer n; + + @JsonProperty("presence_penalty") + private Integer presencePenalty; + + @JsonProperty("response_format") + private ResponseFormat responseFormat; + + @JsonProperty("seed") + private Integer seed; + + @JsonProperty("stop") + private List stop; + + @JsonProperty("stream") + private Boolean stream; + + @JsonProperty("temperature") + private Double temperature; + + @JsonProperty("top_p") + private Integer topP; + + @JsonProperty("tools") + private List tools; + + @JsonProperty("tool_choice") + private String toolChoice; + + @JsonProperty("user") + private String user; + + public Builder messages(List messages) { + this.messages = messages; + return this; + } + + public Builder model(String model) { + this.model = model; + return this; + } + + public Builder frequencyPenalty(Integer frequencyPenalty) { + this.frequencyPenalty = frequencyPenalty; + return this; + } + + public Builder logitBias(Map logitBias) { + this.logitBias = logitBias; + return this; + } + + public Builder maxTokens(Integer maxTokens) { + this.maxTokens = maxTokens; + return this; + } + + public Builder n(Integer n) { + this.n = n; + return this; + } + + public Builder presencePenalty(Integer presencePenalty) { + this.presencePenalty = presencePenalty; + return this; + } + + public Builder responseFormat(ResponseFormat responseFormat) { + this.responseFormat = responseFormat; + return this; + } + + public Builder seed(Integer seed) { + this.seed = seed; + return this; + } + + public Builder stop(List stop) { + this.stop = stop; + return this; + } + + public Builder stream(Boolean stream) { + this.stream = stream; + return this; + } + + public Builder temperature(Double temperature) { + this.temperature = temperature; + return this; + } + + public Builder topP(Integer topP) { + this.topP = topP; + return this; + } + + public Builder tools(List tools) { + this.tools = tools; + return this; + } + + public Builder toolChoice(String toolChoice) { + this.toolChoice = toolChoice; + return this; + } + + public Builder user(String user) { + this.user = user; + return this; + } + + public ChatCompletionsRequest build() { + return new ChatCompletionsRequest(this); + } + } + + public record Function( + + @JsonProperty("name") + String name, + + @JsonProperty("description") + String description, + + @JsonProperty("parameters") + Map parameters, + + @JsonProperty("arguments") + String arguments + ) { + } + + public record ResponseFormat( + + @JsonProperty("type") + String type + ) { + } + + public record Tool( + + @JsonProperty("function") + Function function, + + @JsonProperty("type") + String type + ) { + } +} diff --git a/spring-ai-openai/src/main/java/org/springframework/ai/openai/client/OpenAiChatMessage.java b/spring-ai-openai/src/main/java/org/springframework/ai/openai/client/OpenAiChatMessage.java new file mode 100644 index 00000000000..0ec4cb268d8 --- /dev/null +++ b/spring-ai-openai/src/main/java/org/springframework/ai/openai/client/OpenAiChatMessage.java @@ -0,0 +1,89 @@ +package org.springframework.ai.openai.client; + +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.databind.annotation.JsonDeserialize; + +import java.util.List; + +@JsonDeserialize(builder = OpenAiChatMessage.Builder.class) +public class OpenAiChatMessage { + + private final String role; + private final String name; + private final List toolCalls; + private final String content; + + private OpenAiChatMessage(Builder builder) { + this.role = builder.role; + this.name = builder.name; + this.toolCalls = builder.toolCalls; + this.content = builder.content; + } + + public String getRole() { + return role; + } + + public String getName() { + return name; + } + + public List getToolCalls() { + return toolCalls; + } + + public String getContent() { + return content; + } + + public static class Builder { + + @JsonProperty("role") + private String role; + + @JsonProperty("name") + private String name; + + @JsonProperty("tool_calls") + private List toolCalls; + + @JsonProperty("content") + private String content; + + public Builder role(String role) { + this.role = role; + return this; + } + + public Builder name(String name) { + this.name = name; + return this; + } + + public Builder toolCalls(List toolCalls) { + this.toolCalls = toolCalls; + return this; + } + + public Builder content(String content) { + this.content = content; + return this; + } + + public OpenAiChatMessage build() { + return new OpenAiChatMessage(this); + } + } + public record ToolCall( + + @JsonProperty("function") + ChatCompletionsRequest.Function function, + + @JsonProperty("id") + String id, + + @JsonProperty("type") + String type + ) { + } +} \ No newline at end of file diff --git a/spring-ai-openai/src/main/java/org/springframework/ai/openai/client/OpenAiSseResponse.java b/spring-ai-openai/src/main/java/org/springframework/ai/openai/client/OpenAiSseResponse.java new file mode 100644 index 00000000000..491c85b6374 --- /dev/null +++ b/spring-ai-openai/src/main/java/org/springframework/ai/openai/client/OpenAiSseResponse.java @@ -0,0 +1,48 @@ +package org.springframework.ai.openai.client; + +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; + +public record OpenAiSseResponse( + + @JsonProperty("created") + Integer created, + + @JsonProperty("model") + String model, + + @JsonProperty("id") + String id, + + @JsonProperty("system_fingerprint") + String systemFingerprint, + + @JsonProperty("choices") + List choices, + + @JsonProperty("object") + String object +) { + public record Choice( + + @JsonProperty("finish_reason") + String finishReason, + + @JsonProperty("delta") + Delta delta, + + @JsonProperty("index") + Integer index + ) { + public record Delta( + + @JsonProperty("role") + String role, + + @JsonProperty("content") + String content + ) { + } + } +} \ No newline at end of file From 2ce8f39dc8c4297a0296270c338d463782aa8e42 Mon Sep 17 00:00:00 2001 From: Jemin Huh Date: Sun, 26 Nov 2023 16:37:16 +0900 Subject: [PATCH 3/3] Add spring-boot-starter-webflux dependency, Implement AiStreamClient interface extending AiClient, Implement OpenAiStreamClient using Reactor Flux --- spring-ai-openai/pom.xml | 5 + .../ai/openai/client/AiStreamClient.java | 33 +++ .../ai/openai/client/OpenAiStreamClient.java | 114 ++++++-- .../ai/openai/OpenAiTestConfiguration.java | 62 +++-- .../ai/openai/client/ClientIT.java | 256 ++++++++++-------- 5 files changed, 301 insertions(+), 169 deletions(-) create mode 100644 spring-ai-openai/src/main/java/org/springframework/ai/openai/client/AiStreamClient.java diff --git a/spring-ai-openai/pom.xml b/spring-ai-openai/pom.xml index be1a8f2a3df..51219d30aad 100644 --- a/spring-ai-openai/pom.xml +++ b/spring-ai-openai/pom.xml @@ -56,6 +56,11 @@ spring-boot-starter-logging + + org.springframework.boot + spring-boot-starter-webflux + + org.springframework.boot diff --git a/spring-ai-openai/src/main/java/org/springframework/ai/openai/client/AiStreamClient.java b/spring-ai-openai/src/main/java/org/springframework/ai/openai/client/AiStreamClient.java new file mode 100644 index 00000000000..3e2371bcffa --- /dev/null +++ b/spring-ai-openai/src/main/java/org/springframework/ai/openai/client/AiStreamClient.java @@ -0,0 +1,33 @@ +/* + * Copyright 2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.openai.client; + +import org.springframework.ai.client.AiClient; +import org.springframework.ai.prompt.Prompt; +import org.springframework.ai.prompt.messages.UserMessage; +import reactor.core.publisher.Flux; + +public interface AiStreamClient extends AiClient { + + default Flux generateStream(String message) { + Prompt prompt = new Prompt(new UserMessage(message)); + return generateStream(prompt); + } + + Flux generateStream(Prompt prompt); + +} diff --git a/spring-ai-openai/src/main/java/org/springframework/ai/openai/client/OpenAiStreamClient.java b/spring-ai-openai/src/main/java/org/springframework/ai/openai/client/OpenAiStreamClient.java index f6073800912..754101cf5eb 100644 --- a/spring-ai-openai/src/main/java/org/springframework/ai/openai/client/OpenAiStreamClient.java +++ b/spring-ai-openai/src/main/java/org/springframework/ai/openai/client/OpenAiStreamClient.java @@ -16,37 +16,111 @@ package org.springframework.ai.openai.client; -import com.theokanning.openai.completion.chat.ChatCompletionChunk; -import com.theokanning.openai.completion.chat.ChatCompletionRequest; -import com.theokanning.openai.completion.chat.ChatMessage; -import com.theokanning.openai.service.OpenAiService; -import io.reactivex.Flowable; + +import com.fasterxml.jackson.annotation.JsonInclude; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.databind.DeserializationFeature; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.client.AiResponse; +import org.springframework.ai.client.Generation; import org.springframework.ai.prompt.Prompt; -import org.springframework.ai.prompt.messages.Message; +import org.springframework.core.ParameterizedTypeReference; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; +import org.springframework.http.codec.ServerSentEvent; +import org.springframework.web.reactive.function.client.WebClient; +import reactor.core.publisher.Flux; import java.util.List; +import java.util.Map; +import java.util.function.Predicate; +import java.util.stream.Collectors; + +public class OpenAiStreamClient implements AiStreamClient { + + private Double temperature = 0.7; + + private String model = "gpt-3.5-turbo"; + + private final Logger logger = LoggerFactory.getLogger(getClass()); + + private final WebClient webClient; + private final ObjectMapper objectMapper; + + private final ParameterizedTypeReference> sseType; + + public OpenAiStreamClient(String openAiApiToken) { + this("https://api.openai.com/", openAiApiToken); + } + + public OpenAiStreamClient(String openAiEndpoint, String openAiApiToken) { + this.webClient = WebClient.builder().baseUrl(openAiEndpoint) + .defaultHeader(HttpHeaders.CONTENT_TYPE, MediaType.APPLICATION_JSON_VALUE) + .defaultHeader(HttpHeaders.AUTHORIZATION, "Bearer " + openAiApiToken).build(); + this.objectMapper = new ObjectMapper().setSerializationInclusion(JsonInclude.Include.NON_NULL) + .disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES) + .configure(DeserializationFeature.ACCEPT_EMPTY_STRING_AS_NULL_OBJECT, true); + this.sseType = new ParameterizedTypeReference<>() {}; + } + + @Override + public AiResponse generate(Prompt prompt) { + List openAiChatMessages = prompt.getMessages().stream() + .map(message -> new OpenAiChatMessage.Builder().role(message.getMessageTypeValue()).content( + message.getContent()).build()).toList(); + + ChatCompletionsRequest chatCompletionsRequest = + new ChatCompletionsRequest.Builder().stream(false).model(this.model).temperature(this.temperature) + .messages(openAiChatMessages).build(); + + return getAiResponse(chatCompletionsRequest); + } -public class OpenAiStreamClient extends OpenAiClient { - private final OpenAiService openAiService; + private AiResponse getAiResponse(ChatCompletionsRequest chatCompletionsRequest) { + + logger.trace("ChatMessages: {}", chatCompletionsRequest.getMessages()); + + List chatCompletionChoices = + createChatCompletion(chatCompletionsRequest).bodyToMono(ChatCompletionResponse.class) + .map(ChatCompletionResponse::choices).block(); + + logger.trace("ChatCompletionChoice: {}", chatCompletionChoices); + + return new AiResponse(chatCompletionChoices.stream().map(ChatCompletionResponse.Choice::message) + .map(chatMessage -> new Generation(chatMessage.getContent(), Map.of("role", chatMessage.getRole()))) + .collect(Collectors.toList())); + } - public OpenAiStreamClient(OpenAiService openAiService) { - super(openAiService); - this.openAiService = openAiService; + private WebClient.ResponseSpec createChatCompletion(ChatCompletionsRequest chatCompletionsRequest) { + return this.webClient.post().uri("/v1/chat/completions") + .bodyValue(objectMapper.convertValue(chatCompletionsRequest, JsonNode.class)).retrieve(); } - public Flowable generateStream(Prompt prompt) { + @Override + public Flux generateStream(Prompt prompt) { - List messages = prompt.getMessages(); + List openAiChatMessages = prompt.getMessages().stream() + .map(message -> new OpenAiChatMessage.Builder().role(message.getMessageTypeValue()).content( + message.getContent()).build()).toList(); - List theoMessages = - messages.stream().map(message -> new ChatMessage(message.getMessageTypeValue(), message.getContent())) - .toList(); + ChatCompletionsRequest chatCompletionsRequest = + new ChatCompletionsRequest.Builder().stream(true).model(this.model).temperature(this.temperature) + .messages(openAiChatMessages).build(); - ChatCompletionRequest chatCompletionRequest = - ChatCompletionRequest.builder().model(getModel()).temperature(getTemperature()).messages(theoMessages) - .stream(true).build(); + logger.trace("ChatMessages: {}", chatCompletionsRequest.getMessages()); - return this.openAiService.streamChatCompletion(chatCompletionRequest); + return createChatCompletion(chatCompletionsRequest).bodyToFlux(sseType).map(ServerSentEvent::data) + .filter(Predicate.not("[DONE]"::equals)) + .handle((data, sink) -> { + try { + sink.next(objectMapper.readValue(data, OpenAiSseResponse.class)); + } catch (JsonProcessingException e) { + sink.error(new RuntimeException(e)); + } + }); } } diff --git a/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java b/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java index 6b15ed07c87..391e330ce2f 100644 --- a/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java +++ b/spring-ai-openai/src/test/java/org/springframework/ai/openai/OpenAiTestConfiguration.java @@ -14,35 +14,37 @@ @SpringBootConfiguration public class OpenAiTestConfiguration { - @Bean - public OpenAiService theoOpenAiService() { - String apiKey = System.getenv("OPENAI_API_KEY"); - if (!StringUtils.hasText(apiKey)) { - throw new IllegalArgumentException( - "You must provide an API key. Put it in an environment variable under the name OPENAI_API_KEY"); - } - OpenAiService openAiService = new OpenAiService(apiKey, Duration.ofSeconds(60)); - return openAiService; - } - - @Bean - public OpenAiClient openAiClient(OpenAiService theoOpenAiService) { - OpenAiClient openAiClient = new OpenAiClient(theoOpenAiService); - openAiClient.setTemperature(0.3); - return openAiClient; - } - - @Bean - public EmbeddingClient openAiEmbeddingClient(OpenAiService theoOpenAiService) { - return new OpenAiEmbeddingClient(theoOpenAiService); - } - - - @Bean - public OpenAiStreamClient openAiStreamClient(OpenAiService theoOpenAiService) { - OpenAiStreamClient OpenAiStreamClient = new OpenAiStreamClient(theoOpenAiService); - OpenAiStreamClient.setTemperature(0.3); - return OpenAiStreamClient; - } + @Bean + public OpenAiService theoOpenAiService() { + String apiKey = getApiKey(); + OpenAiService openAiService = new OpenAiService(apiKey, Duration.ofSeconds(60)); + return openAiService; + } + + private String getApiKey() { + String apiKey = System.getenv("OPENAI_API_KEY"); + if (!StringUtils.hasText(apiKey)) { + throw new IllegalArgumentException( + "You must provide an API key. Put it in an environment variable under the name OPENAI_API_KEY"); + } + return apiKey; + } + + @Bean + public OpenAiClient openAiClient(OpenAiService theoOpenAiService) { + OpenAiClient openAiClient = new OpenAiClient(theoOpenAiService); + openAiClient.setTemperature(0.3); + return openAiClient; + } + + @Bean + public EmbeddingClient openAiEmbeddingClient(OpenAiService theoOpenAiService) { + return new OpenAiEmbeddingClient(theoOpenAiService); + } + + @Bean + public OpenAiStreamClient openAiStreamClient() { + return new OpenAiStreamClient(getApiKey()); + } } diff --git a/spring-ai-openai/src/test/java/org/springframework/ai/openai/client/ClientIT.java b/spring-ai-openai/src/test/java/org/springframework/ai/openai/client/ClientIT.java index 9328e16ddbc..debf1d182f9 100644 --- a/spring-ai-openai/src/test/java/org/springframework/ai/openai/client/ClientIT.java +++ b/spring-ai-openai/src/test/java/org/springframework/ai/openai/client/ClientIT.java @@ -1,8 +1,5 @@ package org.springframework.ai.openai.client; -import com.theokanning.openai.completion.chat.ChatCompletionChoice; -import com.theokanning.openai.completion.chat.ChatCompletionChunk; -import com.theokanning.openai.completion.chat.ChatMessage; import org.junit.jupiter.api.Test; import org.springframework.ai.client.AiResponse; import org.springframework.ai.client.Generation; @@ -30,124 +27,145 @@ @SpringBootTest class ClientIT extends AbstractIT { - @Value("classpath:/prompts/system-message.st") - private Resource systemResource; - - @Test - void roleTest() { - String request = "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."; - String name = "Bob"; - String voice = "pirate"; - UserMessage userMessage = new UserMessage(request); - SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); - Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice)); - Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); - AiResponse response = openAiClient.generate(prompt); - // needs fine tuning... evaluateQuestionAndAnswer(request, response, false); - } - - @Test - void outputParser() { - DefaultConversionService conversionService = new DefaultConversionService(); - ListOutputParser outputParser = new ListOutputParser(conversionService); - - String format = outputParser.getFormat(); - String template = """ - List five {subject} - {format} - """; - PromptTemplate promptTemplate = new PromptTemplate(template, - Map.of("subject", "ice cream flavors", "format", format)); - Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = this.openAiClient.generate(prompt).getGeneration(); - - List list = outputParser.parse(generation.getText()); - System.out.println(list); - assertThat(list).hasSize(5); - - } - - @Test - void mapOutputParser() { - MapOutputParser outputParser = new MapOutputParser(); - - String format = outputParser.getFormat(); - String template = """ - Provide me a List of {subject} - {format} - """; - PromptTemplate promptTemplate = new PromptTemplate(template, - Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); - Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = openAiClient.generate(prompt).getGeneration(); - - Map result = outputParser.parse(generation.getText()); - assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); - - } - - @Test - void beanOutputParser() { - - BeanOutputParser outputParser = new BeanOutputParser<>(ActorsFilms.class); - - String format = outputParser.getFormat(); - String template = """ - Generate the filmography for a random actor. - {format} - """; - PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); - Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = openAiClient.generate(prompt).getGeneration(); - - ActorsFilms actorsFilms = outputParser.parse(generation.getText()); - System.out.println(actorsFilms); - } - - record ActorsFilmsRecord(String actor, List movies) { - } - - @Test - void beanOutputParserRecords() { - - BeanOutputParser outputParser = new BeanOutputParser<>(ActorsFilmsRecord.class); - - String format = outputParser.getFormat(); - String template = """ - Generate the filmography of 5 movies for Tom Hanks. - {format} - """; - PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); - Prompt prompt = new Prompt(promptTemplate.createMessage()); - Generation generation = openAiClient.generate(prompt).getGeneration(); - - ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getText()); - System.out.println(actorsFilms); - assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); - assertThat(actorsFilms.movies()).hasSize(5); - } - - @Test - void beanStreamOutputParserRecords() { - - BeanOutputParser outputParser = new BeanOutputParser<>(ActorsFilmsRecord.class); - - String format = outputParser.getFormat(); - String template = """ + @Value("classpath:/prompts/system-message.st") + private Resource systemResource; + + @Test + void roleTest() { + String request = "Tell me about 3 famous pirates from the Golden Age of Piracy and why they did."; + String name = "Bob"; + String voice = "pirate"; + UserMessage userMessage = new UserMessage(request); + SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource); + Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", name, "voice", voice)); + Prompt prompt = new Prompt(List.of(userMessage, systemMessage)); + AiResponse response = openAiClient.generate(prompt); + // needs fine tuning... evaluateQuestionAndAnswer(request, response, false); + } + + @Test + void outputParser() { + DefaultConversionService conversionService = new DefaultConversionService(); + ListOutputParser outputParser = new ListOutputParser(conversionService); + + String format = outputParser.getFormat(); + String template = """ + List five {subject} + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, + Map.of("subject", "ice cream flavors", "format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = this.openAiClient.generate(prompt).getGeneration(); + + List list = outputParser.parse(generation.getText()); + System.out.println(list); + assertThat(list).hasSize(5); + + } + + @Test + void mapOutputParser() { + MapOutputParser outputParser = new MapOutputParser(); + + String format = outputParser.getFormat(); + String template = """ + Provide me a List of {subject} + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, + Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = openAiClient.generate(prompt).getGeneration(); + + Map result = outputParser.parse(generation.getText()); + assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9)); + + } + + @Test + void beanOutputParser() { + + BeanOutputParser outputParser = new BeanOutputParser<>(ActorsFilms.class); + + String format = outputParser.getFormat(); + String template = """ + Generate the filmography for a random actor. + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = openAiClient.generate(prompt).getGeneration(); + + ActorsFilms actorsFilms = outputParser.parse(generation.getText()); + System.out.println(actorsFilms); + } + + record ActorsFilmsRecord(String actor, List movies) { + } + + @Test + void beanOutputParserRecords() { + + BeanOutputParser outputParser = new BeanOutputParser<>(ActorsFilmsRecord.class); + + String format = outputParser.getFormat(); + String template = """ + Generate the filmography of 5 movies for Tom Hanks. + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = openAiClient.generate(prompt).getGeneration(); + + ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getText()); + System.out.println(actorsFilms); + assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); + assertThat(actorsFilms.movies()).hasSize(5); + } + + @Test + void beanMonoOutputParserRecords() { + + BeanOutputParser outputParser = new BeanOutputParser<>(ActorsFilmsRecord.class); + + String format = outputParser.getFormat(); + String template = """ + Generate the filmography of 5 movies for Tom Hanks. + {format} + """; + PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + Generation generation = openAiStreamClient.generate(prompt).getGeneration(); + + ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getText()); + System.out.println(actorsFilms); + assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); + assertThat(actorsFilms.movies()).hasSize(5); + } + + @Test + void beanStreamOutputParserRecords() { + + BeanOutputParser outputParser = new BeanOutputParser<>(ActorsFilmsRecord.class); + + String format = outputParser.getFormat(); + String template = """ Generate the filmography of 5 movies for Tom Hanks. {format} """; - PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); - Prompt prompt = new Prompt(promptTemplate.createMessage()); - String generationTextFromStream = openAiStreamClient.generateStream(prompt).toList().blockingGet().stream() - .map(ChatCompletionChunk::getChoices).flatMap(List::stream) - .map(ChatCompletionChoice::getMessage).map(ChatMessage::getContent) - .collect(Collectors.joining()); - - ActorsFilmsRecord actorsFilms = outputParser.parse(generationTextFromStream); - System.out.println(actorsFilms); - assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); - assertThat(actorsFilms.movies()).hasSize(5); - } + PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format)); + Prompt prompt = new Prompt(promptTemplate.createMessage()); + + String generationTextFromStream = + openAiStreamClient.generateStream(prompt).map(OpenAiSseResponse::choices) + .toStream().flatMap(List::stream).map(OpenAiSseResponse.Choice::delta) + .map(OpenAiSseResponse.Choice.Delta::content).collect(Collectors.joining()); + + ActorsFilmsRecord actorsFilms = outputParser.parse(generationTextFromStream); + System.out.println(actorsFilms); + assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks"); + assertThat(actorsFilms.movies()).hasSize(5); + } }