diff --git a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatClientIT.java b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatClientIT.java
index 7a86a48553..4fe481ea49 100644
--- a/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatClientIT.java
+++ b/models/spring-ai-anthropic/src/test/java/org/springframework/ai/anthropic/AnthropicChatClientIT.java
@@ -180,7 +180,7 @@ void multiModalityTest() throws IOException {
 
         byte[] imageData = new ClassPathResource("/test.png").getContentAsByteArray();
 
-        var userMessage = new UserMessage("Explain what do you see o this picture?",
+        var userMessage = new UserMessage("Explain what do you see on this picture?",
                 List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData)));
 
         ChatResponse response = chatClient.call(new Prompt(List.of(userMessage)));
diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java
index 79c4b6498c..8e86ecdd19 100644
--- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java
+++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatClient.java
@@ -15,6 +15,8 @@
  */
 package org.springframework.ai.openai;
 
+import java.util.ArrayList;
+import java.util.Base64;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
@@ -43,6 +45,7 @@
 import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion.Choice;
 import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionFinishReason;
 import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.MediaContent;
 import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.Role;
 import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ToolCall;
 import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
@@ -53,6 +56,7 @@
 import org.springframework.retry.support.RetryTemplate;
 import org.springframework.util.Assert;
 import org.springframework.util.CollectionUtils;
+import org.springframework.util.MimeType;
 
 /**
  * {@link ChatClient} and {@link StreamingChatClient} implementation for {@literal OpenAI}
@@ -240,11 +244,20 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
 
         Set<String> functionsForThisRequest = new HashSet<>();
 
-        List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions()
-            .stream()
-            .map(m -> new ChatCompletionMessage(m.getContent(),
-                    ChatCompletionMessage.Role.valueOf(m.getMessageType().name())))
-            .toList();
+        List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions().stream().map(m -> {
+            // Add text content.
+            List<MediaContent> contents = new ArrayList<>(List.of(new MediaContent(m.getContent())));
+            if (!CollectionUtils.isEmpty(m.getMedia())) {
+                // Add media content.
+                contents.addAll(m.getMedia()
+                    .stream()
+                    .map(media -> new MediaContent(
+                            new MediaContent.ImageUrl(this.fromMediaData(media.getMimeType(), media.getData()))))
+                    .toList());
+            }
+
+            return new ChatCompletionMessage(contents, ChatCompletionMessage.Role.valueOf(m.getMessageType().name()));
+        }).toList();
 
         ChatCompletionRequest request = new ChatCompletionRequest(chatCompletionMessages, stream);
 
@@ -286,6 +299,22 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
         return request;
     }
 
+    private String fromMediaData(MimeType mimeType, Object mediaContentData) {
+        if (mediaContentData instanceof byte[] bytes) {
+            // Assume the bytes are an image and convert them into a base64-encoded string,
+            // following the data URL prefix pattern.
+            return String.format("data:%s;base64,%s", mimeType.toString(), Base64.getEncoder().encodeToString(bytes));
+        }
+        else if (mediaContentData instanceof String text) {
+            // Assume the text is a URL or a base64-encoded image already prefixed by the user.
+            return text;
+        }
+        else {
+            throw new IllegalArgumentException(
+                    "Unsupported media data type: " + mediaContentData.getClass().getSimpleName());
+        }
+    }
+
     private List<OpenAiApi.FunctionTool> getFunctionTools(Set<String> functionNames) {
         return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> {
             var function = new OpenAiApi.FunctionTool.Function(functionCallback.getDescription(),
diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java
index 0ad2bcf0fb..5edfdc9e7c 100644
--- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java
+++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/api/OpenAiApi.java
@@ -407,10 +407,12 @@ public record ResponseFormat(
     /**
      * Message comprising the conversation.
      *
-     * @param content The contents of the message.
+     * @param rawContent The contents of the message. Can be either a {@link String} or a list of
+     * {@link MediaContent} parts. The response message content is always a {@link String}.
      * @param role The role of the messages author. Could be one of the {@link Role} types.
     * @param name An optional name for the participant. Provides the model information to differentiate between
-     * participants of the same role.
+     * participants of the same role. In the case of function calling, the name is the name of the function
+     * that the message is responding to.
      * @param toolCallId Tool call that this message is responding to. Only applicable for the {@link Role#TOOL} role
      * and null otherwise.
      * @param toolCalls The tool calls generated by the model, such as function calls. Applicable only for
@@ -418,18 +420,31 @@ public record ResponseFormat(
      */
     @JsonInclude(Include.NON_NULL)
     public record ChatCompletionMessage(
-            @JsonProperty("content") String content,
+            @JsonProperty("content") Object rawContent,
             @JsonProperty("role") Role role,
             @JsonProperty("name") String name,
             @JsonProperty("tool_call_id") String toolCallId,
             @JsonProperty("tool_calls") List<ToolCall> toolCalls) {
 
+        /**
+         * Get the message content as a String.
+         */
+        public String content() {
+            if (this.rawContent == null) {
+                return null;
+            }
+            if (this.rawContent instanceof String text) {
+                return text;
+            }
+            throw new IllegalStateException("The content is not a string!");
+        }
+
         /**
          * Create a chat completion message with the given content and role. All other fields are null.
          * @param content The contents of the message.
         * @param role The role of the author of this message.
         */
-        public ChatCompletionMessage(String content, Role role) {
+        public ChatCompletionMessage(Object content, Role role) {
            this(content, role, null, null, null);
         }
 
@@ -455,6 +470,54 @@ public enum Role {
         @JsonProperty("tool") TOOL
     }
 
+        /**
+         * An array of content parts with a defined type.
+         * Each MediaContent can be of either "text" or "image_url" type, not both.
+         *
+         * @param type Content type: either "text" or "image_url".
+         * @param text The text content of the message.
+         * @param imageUrl The image content of the message. You can pass multiple
+         * images by adding multiple image_url content parts. Image input is only
+         * supported when using the gpt-4-vision-preview model.
+         */
+        @JsonInclude(Include.NON_NULL)
+        public record MediaContent(
+                @JsonProperty("type") String type,
+                @JsonProperty("text") String text,
+                @JsonProperty("image_url") ImageUrl imageUrl) {
+
+            /**
+             * @param url Either a URL of the image or the base64-encoded image data.
+             * The base64-encoded image data must carry a special prefix in the following format:
+             * "data:{mimetype};base64,{base64-encoded-image-data}".
+             * @param detail Specifies the detail level of the image.
+             */
+            @JsonInclude(Include.NON_NULL)
+            public record ImageUrl(
+                    @JsonProperty("url") String url,
+                    @JsonProperty("detail") String detail) {
+
+                public ImageUrl(String url) {
+                    this(url, null);
+                }
+            }
+
+            /**
+             * Shortcut constructor for a text content.
+             * @param text The text content of the message.
+             */
+            public MediaContent(String text) {
+                this("text", text, null);
+            }
+
+            /**
+             * Shortcut constructor for an image content.
+             * @param imageUrl The image content of the message.
+             */
+            public MediaContent(ImageUrl imageUrl) {
+                this("image_url", null, imageUrl);
+            }
+        }
 
         /**
          * The relevant tool call.
         *
@@ -483,6 +546,13 @@ public record ChatCompletionFunction(
         }
     }
 
+    public static String getTextContent(List<ChatCompletionMessage.MediaContent> content) {
+        return content.stream()
+            .filter(c -> "text".equals(c.type()))
+            .map(ChatCompletionMessage.MediaContent::text)
+            .reduce("", (a, b) -> a + b);
+    }
+
     /**
      * The reason the model stopped generating tokens.
      */
diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java
index 182a11dcec..25e553c1bf 100644
--- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/api/OpenAiApiIT.java
@@ -21,7 +21,6 @@
 import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
 import reactor.core.publisher.Flux;
 
-import org.springframework.ai.openai.api.OpenAiApi;
 import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
 import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk;
 import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage;
diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/speech/OpenAiSpeechClientIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/speech/OpenAiSpeechClientIT.java
index 72b982cd6b..5fdc3e24b8 100644
--- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/speech/OpenAiSpeechClientIT.java
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/speech/OpenAiSpeechClientIT.java
@@ -38,8 +38,7 @@ class OpenAiSpeechClientIT extends AbstractIT {
 
     @Test
     void shouldSuccessfullyStreamAudioBytesForEmptyMessage() {
-        Flux<byte[]> response = openAiAudioSpeechClient
-            .stream("Today is a wonderful day to build something people love!");
+        Flux<byte[]> response = speechClient.stream("Today is a wonderful day to build something people love!");
         assertThat(response).isNotNull();
         assertThat(response.collectList().block()).isNotNull();
         System.out.println(response.collectList().block());
@@ -47,7 +46,7 @@ void shouldSuccessfullyStreamAudioBytesForEmptyMessage() {
 
     @Test
     void shouldProduceAudioBytesDirectlyFromMessage() {
-        byte[] audioBytes = openAiAudioSpeechClient.call("Today is a wonderful day to build something people love!");
+        byte[] audioBytes = speechClient.call("Today is a wonderful day to build something people love!");
         assertThat(audioBytes).hasSizeGreaterThan(0);
     }
 
@@ -62,7 +61,7 @@ void shouldGenerateNonEmptyMp3AudioFromSpeechPrompt() {
             .build();
         SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!",
                 speechOptions);
-        SpeechResponse response = openAiAudioSpeechClient.call(speechPrompt);
+        SpeechResponse response = speechClient.call(speechPrompt);
         byte[] audioBytes = response.getResult().getOutput();
         assertThat(response.getResults()).hasSize(1);
         assertThat(response.getResults().get(0).getOutput()).isNotEmpty();
@@ -80,7 +79,7 @@ void speechRateLimitTest() {
             .build();
         SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!",
                 speechOptions);
-        SpeechResponse response = openAiAudioSpeechClient.call(speechPrompt);
+        SpeechResponse response = speechClient.call(speechPrompt);
         OpenAiAudioSpeechResponseMetadata metadata = response.getMetadata();
         assertThat(metadata).isNotNull();
         assertThat(metadata.getRateLimit()).isNotNull();
@@ -101,7 +100,7 @@ void shouldStreamNonEmptyResponsesForValidSpeechPrompts() {
 
         SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!",
                 speechOptions);
-        Flux<SpeechResponse> responseFlux = openAiAudioSpeechClient.stream(speechPrompt);
+        Flux<SpeechResponse> responseFlux = speechClient.stream(speechPrompt);
         assertThat(responseFlux).isNotNull();
         List<SpeechResponse> responses = responseFlux.collectList().block();
         assertThat(responses).isNotNull();
diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiTranscriptionClientIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiTranscriptionClientIT.java
index 9c0845deea..dcb10cd102 100644
--- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiTranscriptionClientIT.java
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/audio/transcription/OpenAiTranscriptionClientIT.java
@@ -43,7 +43,7 @@ void transcriptionTest() {
             .withTemperature(0f)
             .build();
         AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(audioFile, transcriptionOptions);
-        AudioTranscriptionResponse response = openAiTranscriptionClient.call(transcriptionRequest);
+        AudioTranscriptionResponse response = transcriptionClient.call(transcriptionRequest);
         assertThat(response.getResults()).hasSize(1);
         assertThat(response.getResults().get(0).getOutput().toLowerCase().contains("fellow")).isTrue();
     }
@@ -59,7 +59,7 @@ void transcriptionTestWithOptions() {
             .withResponseFormat(responseFormat)
             .build();
         AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(audioFile, transcriptionOptions);
-        AudioTranscriptionResponse response = openAiTranscriptionClient.call(transcriptionRequest);
+        AudioTranscriptionResponse response = transcriptionClient.call(transcriptionRequest);
         assertThat(response.getResults()).hasSize(1);
         assertThat(response.getResults().get(0).getOutput().toLowerCase().contains("fellow")).isTrue();
     }
diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientIT.java
index b23af540b3..82f034ffac 100644
--- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientIT.java
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientIT.java
@@ -15,6 +15,7 @@
  */
 package org.springframework.ai.openai.chat;
 
+import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
@@ -30,6 +31,7 @@
 import org.springframework.ai.chat.ChatResponse;
 import org.springframework.ai.chat.Generation;
 import org.springframework.ai.chat.messages.AssistantMessage;
+import org.springframework.ai.chat.messages.Media;
 import org.springframework.ai.chat.messages.Message;
 import org.springframework.ai.chat.messages.UserMessage;
 import org.springframework.ai.chat.prompt.Prompt;
@@ -47,7 +49,9 @@
 import org.springframework.beans.factory.annotation.Value;
 import org.springframework.boot.test.context.SpringBootTest;
 import org.springframework.core.convert.support.DefaultConversionService;
+import org.springframework.core.io.ClassPathResource;
 import org.springframework.core.io.Resource;
+import org.springframework.util.MimeTypeUtils;
 
 import static org.assertj.core.api.Assertions.assertThat;
 
@@ -67,7 +71,7 @@ void roleTest() {
         SystemPromptTemplate systemPromptTemplate = new SystemPromptTemplate(systemResource);
         Message systemMessage = systemPromptTemplate.createMessage(Map.of("name", "Bob", "voice", "pirate"));
         Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
-        ChatResponse response = openAiChatClient.call(prompt);
+        ChatResponse response = chatClient.call(prompt);
         assertThat(response.getResults()).hasSize(1);
         assertThat(response.getResults().get(0).getOutput().getContent()).contains("Blackbeard");
         // needs fine tuning... evaluateQuestionAndAnswer(request, response, false);
@@ -86,7 +90,7 @@ void outputParser() {
         PromptTemplate promptTemplate = new PromptTemplate(template,
                 Map.of("subject", "ice cream flavors", "format", format));
         Prompt prompt = new Prompt(promptTemplate.createMessage());
-        Generation generation = this.openAiChatClient.call(prompt).getResult();
+        Generation generation = this.chatClient.call(prompt).getResult();
 
         List<String> list = outputParser.parse(generation.getOutput().getContent());
         assertThat(list).hasSize(5);
@@ -105,7 +109,7 @@ void mapOutputParser() {
         PromptTemplate promptTemplate = new PromptTemplate(template,
                 Map.of("subject", "an array of numbers from 1 to 9 under they key name 'numbers'", "format", format));
         Prompt prompt = new Prompt(promptTemplate.createMessage());
-        Generation generation = openAiChatClient.call(prompt).getResult();
+        Generation generation = chatClient.call(prompt).getResult();
 
         Map<String, Object> result = outputParser.parse(generation.getOutput().getContent());
         assertThat(result.get("numbers")).isEqualTo(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9));
@@ -124,7 +128,7 @@ void beanOutputParser() {
             """;
         PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
         Prompt prompt = new Prompt(promptTemplate.createMessage());
-        Generation generation = openAiChatClient.call(prompt).getResult();
+        Generation generation = chatClient.call(prompt).getResult();
 
         ActorsFilms actorsFilms = outputParser.parse(generation.getOutput().getContent());
     }
@@ -144,7 +148,7 @@ void beanOutputParserRecords() {
             """;
         PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
         Prompt prompt = new Prompt(promptTemplate.createMessage());
-        Generation generation = openAiChatClient.call(prompt).getResult();
+        Generation generation = chatClient.call(prompt).getResult();
 
         ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getOutput().getContent());
         logger.info("" + actorsFilms);
@@ -165,7 +169,7 @@ void beanStreamOutputParserRecords() {
         PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
         Prompt prompt = new Prompt(promptTemplate.createMessage());
 
-        String generationTextFromStream = openStreamingChatClient.stream(prompt)
+        String generationTextFromStream = streamingChatClient.stream(prompt)
             .collectList()
             .block()
             .stream()
@@ -197,7 +201,7 @@ void functionCallTest() {
                 .build()))
             .build();
 
-        ChatResponse response = openAiChatClient.call(new Prompt(messages, promptOptions));
+        ChatResponse response = chatClient.call(new Prompt(messages, promptOptions));
 
         logger.info("Response: {}", response);
 
@@ -222,7 +226,7 @@ void streamFunctionCallTest() {
                 .build()))
             .build();
 
-        Flux<ChatResponse> response = openStreamingChatClient.stream(new Prompt(messages, promptOptions));
+        Flux<ChatResponse> response = streamingChatClient.stream(new Prompt(messages, promptOptions));
 
         String content = response.collectList()
             .block()
@@ -239,4 +243,55 @@ void streamFunctionCallTest() {
         assertThat(content).containsAnyOf("15.0", "15");
     }
 
+    @Test
+    void multiModalityEmbeddedImage() throws IOException {
+
+        byte[] imageData = new ClassPathResource("/test.png").getContentAsByteArray();
+
+        var userMessage = new UserMessage("Explain what do you see on this picture?",
+                List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData)));
+
+        ChatResponse response = chatClient.call(new Prompt(List.of(userMessage),
+                OpenAiChatOptions.builder().withModel(OpenAiApi.ChatModel.GPT_4_VISION_PREVIEW.getValue()).build()));
+
+        logger.info(response.getResult().getOutput().getContent());
+        assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple", "bowl");
+    }
+
+    @Test
+    void multiModalityImageUrl() throws IOException {
+
+        var userMessage = new UserMessage("Explain what do you see on this picture?",
+                List.of(new Media(MimeTypeUtils.IMAGE_PNG,
+                        "https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/_images/multimodal.test.png")));
+
+        ChatResponse response = chatClient.call(new Prompt(List.of(userMessage),
+                OpenAiChatOptions.builder().withModel(OpenAiApi.ChatModel.GPT_4_VISION_PREVIEW.getValue()).build()));
+
+        logger.info(response.getResult().getOutput().getContent());
+        assertThat(response.getResult().getOutput().getContent()).contains("bananas", "apple", "bowl");
+    }
+
+    @Test
+    void streamingMultiModalityImageUrl() throws IOException {
+
+        var userMessage = new UserMessage("Explain what do you see on this picture?",
+                List.of(new Media(MimeTypeUtils.IMAGE_PNG,
+                        "https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/_images/multimodal.test.png")));
+
+        Flux<ChatResponse> response = streamingChatClient.stream(new Prompt(List.of(userMessage),
+                OpenAiChatOptions.builder().withModel(OpenAiApi.ChatModel.GPT_4_VISION_PREVIEW.getValue()).build()));
+
+        String content = response.collectList()
+            .block()
+            .stream()
+            .map(ChatResponse::getResults)
+            .flatMap(List::stream)
+            .map(Generation::getOutput)
+            .map(AssistantMessage::getContent)
+            .collect(Collectors.joining());
+        logger.info("Response: {}", content);
+        assertThat(content).contains("bananas", "apple", "bowl");
+    }
+
 }
\ No newline at end of file
diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientIT.java
index 1a36b449a7..6f1c0c4e1b 100644
--- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientIT.java
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/image/OpenAiImageClientIT.java
@@ -39,7 +39,7 @@ void imageAsUrlTest() {
 
         ImagePrompt imagePrompt = new ImagePrompt(instructions, options);
 
-        ImageResponse imageResponse = openaiImageClient.call(imagePrompt);
+        ImageResponse imageResponse = imageClient.call(imagePrompt);
 
         assertThat(imageResponse.getResults()).hasSize(1);
 
diff --git a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java
index 9c2b8780a4..89155e9e80 100644
--- a/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java
+++ b/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/testutils/AbstractIT.java
@@ -43,19 +43,19 @@ public abstract class AbstractIT {
     private static final Logger logger = LoggerFactory.getLogger(AbstractIT.class);
 
     @Autowired
-    protected ChatClient openAiChatClient;
+    protected ChatClient chatClient;
 
     @Autowired
-    protected OpenAiAudioTranscriptionClient openAiTranscriptionClient;
+    protected StreamingChatClient streamingChatClient;
 
     @Autowired
-    protected OpenAiAudioSpeechClient openAiAudioSpeechClient;
+    protected OpenAiAudioTranscriptionClient transcriptionClient;
 
     @Autowired
-    protected ImageClient openaiImageClient;
+    protected OpenAiAudioSpeechClient speechClient;
 
     @Autowired
-    protected StreamingChatClient openStreamingChatClient;
+    protected ImageClient imageClient;
 
     @Value("classpath:/prompts/eval/qa-evaluator-accurate-answer.st")
     protected Resource qaEvaluatorAccurateAnswerResource;
 
@@ -64,7 +64,7 @@ public abstract class AbstractIT {
     protected Resource qaEvaluatorNotRelatedResource;
 
     @Value("classpath:/prompts/eval/qa-evaluator-fact-based-answer.st")
-    protected Resource qaEvalutaorFactBasedAnswerResource;
+    protected Resource qaEvaluatorFactBasedAnswerResource;
 
     @Value("classpath:/prompts/eval/user-evaluator-message.st")
     protected Resource userEvaluatorResource;
 
@@ -78,19 +78,19 @@ protected void evaluateQuestionAndAnswer(String question, ChatResponse response,
                 Map.of("question", question, "answer", answer));
         SystemMessage systemMessage;
         if (factBased) {
-            systemMessage = new SystemMessage(qaEvalutaorFactBasedAnswerResource);
+            systemMessage = new SystemMessage(qaEvaluatorFactBasedAnswerResource);
         }
         else {
             systemMessage = new SystemMessage(qaEvaluatorAccurateAnswerResource);
         }
         Message userMessage = userPromptTemplate.createMessage();
         Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
-        String yesOrNo = openAiChatClient.call(prompt).getResult().getOutput().getContent();
+        String yesOrNo = chatClient.call(prompt).getResult().getOutput().getContent();
         logger.info("Is Answer related to question: " + yesOrNo);
         if (yesOrNo.equalsIgnoreCase("no")) {
             SystemMessage notRelatedSystemMessage = new SystemMessage(qaEvaluatorNotRelatedResource);
             prompt = new Prompt(List.of(userMessage, notRelatedSystemMessage));
-            String reasonForFailure = openAiChatClient.call(prompt).getResult().getOutput().getContent();
+            String reasonForFailure = chatClient.call(prompt).getResult().getOutput().getContent();
             fail(reasonForFailure);
         }
         else {
diff --git a/models/spring-ai-openai/src/test/resources/test.png b/models/spring-ai-openai/src/test/resources/test.png
new file mode 100644
index 0000000000..8abb4c81ae
Binary files /dev/null and b/models/spring-ai-openai/src/test/resources/test.png differ
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/images/openai-chat-api.jpg b/spring-ai-docs/src/main/antora/modules/ROOT/images/openai-chat-api.jpg
new file mode 100644
index 0000000000..e0c83ae3ea
Binary files /dev/null and b/spring-ai-docs/src/main/antora/modules/ROOT/images/openai-chat-api.jpg differ
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/images/openai-chat-api.png b/spring-ai-docs/src/main/antora/modules/ROOT/images/openai-chat-api.png
deleted file mode 100644
index 83e9f4351d..0000000000
Binary files a/spring-ai-docs/src/main/antora/modules/ROOT/images/openai-chat-api.png and /dev/null differ
diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc
index e098ba29ac..e354042c8a 100644
--- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc
+++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/chat/openai-chat.adoc
@@ -141,6 +141,57 @@ You can register custom Java functions with the OpenAiChatClient and have the Op
 This is a powerful technique to connect the LLM capabilities with external tools and APIs.
 Read more about xref:api/chat/functions/openai-chat-functions.adoc[OpenAI Function Calling].
 
+== Multimodal
+
+Multimodality refers to a model's ability to simultaneously understand and process information from various sources, including text, images, audio, and other data formats.
+Presently, the OpenAI `gpt-4-vision-preview` model offers multimodal support. Refer to the link:https://platform.openai.com/docs/guides/vision[Vision] guide for more information.
+
+The OpenAI link:https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages[User Message API] can incorporate a list of base64-encoded images or image URLs with the message.
+Spring AI’s link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/Message.java[Message] interface facilitates multimodal AI models by introducing the link:https://github.com/spring-projects/spring-ai/blob/main/spring-ai-core/src/main/java/org/springframework/ai/chat/messages/Media.java[Media] type.
+This type encompasses data and details regarding media attachments in messages, utilizing Spring’s `org.springframework.util.MimeType` and a `java.lang.Object` for the raw media data.
+
+Below is a code example excerpted from link:https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-openai/src/test/java/org/springframework/ai/openai/chat/OpenAiChatClientIT.java[OpenAiChatClientIT.java], illustrating the combination of user text with an image.
+
+[source,java]
+----
+byte[] imageData = new ClassPathResource("/multimodal.test.png").getContentAsByteArray();
+
+var userMessage = new UserMessage("Explain what do you see on this picture?",
+        List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData)));
+
+ChatResponse response = chatClient.call(new Prompt(List.of(userMessage),
+        OpenAiChatOptions.builder().withModel(OpenAiApi.ChatModel.GPT_4_VISION_PREVIEW.getValue()).build()));
+----
+
+or the image URL equivalent:
+
+[source,java]
+----
+var userMessage = new UserMessage("Explain what do you see on this picture?",
+        List.of(new Media(MimeTypeUtils.IMAGE_PNG,
+                "https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/_images/multimodal.test.png")));
+
+ChatResponse response = chatClient.call(new Prompt(List.of(userMessage),
+        OpenAiChatOptions.builder().withModel(OpenAiApi.ChatModel.GPT_4_VISION_PREVIEW.getValue()).build()));
+----
+
+TIP: You can pass multiple images as well.
+
+It takes as input the `multimodal.test.png` image:
+
+image::multimodal.test.png[Multimodal Test Image, 200, 200, align="left"]
+
+along with the text message "Explain what do you see on this picture?", and generates a response like this:
+
+----
+This is an image of a fruit bowl with a simple design. The bowl is made of metal with curved wire edges that
+create an open structure, allowing the fruit to be visible from all angles. Inside the bowl, there are two
+yellow bananas resting on top of what appears to be a red apple. The bananas are slightly overripe, as
+indicated by the brown spots on their peels. The bowl has a metal ring at the top, likely to serve as a handle
+for carrying. The bowl is placed on a flat surface with a neutral-colored background that provides a clear
+view of the fruit inside.
+----
+
 == Sample Controller
 
 https://start.spring.io/[Create] a new Spring Boot project and add the `spring-ai-openai-spring-boot-starter` to your pom (or gradle) dependencies.
@@ -239,7 +290,7 @@ The https://github.com/spring-projects/spring-ai/blob/main/models/spring-ai-open
 
 Following class diagram illustrates the `OpenAiApi` chat interfaces and building blocks:
 
-image::openai-chat-api.png[OpenAiApi Chat API Diagram]
+image::openai-chat-api.jpg[OpenAiApi Chat API Diagram, width=1000, align="center"]
 
 Here is a simple snippet how to use the api programmatically:
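The patch's trailing context stops at the sentence above, before the documentation's actual snippet. For illustration only, here is a minimal sketch of how the `MediaContent` and `ImageUrl` types introduced by this patch compose a multimodal request against the low-level `OpenAiApi` client. It assumes an `OpenAiApi(String apiKey)` constructor, a `chatCompletionEntity(ChatCompletionRequest)` method, and a `(messages, model, temperature, stream)` convenience constructor on `ChatCompletionRequest` — none of which are shown in this diff, which only exercises the `(messages, stream)` form.

[source,java]
----
// Sketch under the assumptions stated above, not part of this patch.
OpenAiApi openAiApi = new OpenAiApi(System.getenv("OPENAI_API_KEY"));

// A user message whose content is a list of typed parts — one "text" part and one
// "image_url" part — the same shape OpenAiChatClient.createRequest builds from a
// Media-bearing UserMessage via the new ChatCompletionMessage(Object, Role) constructor.
ChatCompletionMessage userMessage = new ChatCompletionMessage(
        List.of(new MediaContent("Explain what do you see on this picture?"),
                new MediaContent(new MediaContent.ImageUrl(
                        "https://docs.spring.io/spring-ai/reference/1.0-SNAPSHOT/_images/multimodal.test.png"))),
        Role.USER);

ResponseEntity<ChatCompletion> response = openAiApi.chatCompletionEntity(
        new ChatCompletionRequest(List.of(userMessage), "gpt-4-vision-preview", 0.8f, false));

// Response messages always carry String content, so the new content() accessor is safe here.
System.out.println(response.getBody().choices().get(0).message().content());
----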