Add Multi-Modality Support for OpenAI
 - Implemented a MediaContent abstraction within the OpenAiApi to handle text and image inputs.
 - Response message content remains a plain String, ensuring backward compatibility.
 - Extended the OpenAiChatClient request creation process to seamlessly map Spring AI Messages with
   Media content to the low-level OpenAiApi MediaContent types.
 - Added integration tests for embedded and URL images, covering both synchronous and streaming calls.
 - Updated the OpenAI class diagram to reflect the new media content types provided by the OpenAI API.
 - Incorporated a chapter on multi-modality within the openai-chat.adoc documentation.
 - Improved the OpenAI multi-modality documentation (a short usage sketch follows the change summary below).
tzolov authored and markpollack committed Mar 21, 2024
1 parent b8f773c commit 834d2d0
Showing 13 changed files with 241 additions and 38 deletions.
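
For orientation, the new API lets callers attach images to a user message. A minimal usage sketch, adapted from the integration test changed below (imports and chatClient setup omitted):

    byte[] imageData = new ClassPathResource("/test.png").getContentAsByteArray();
    var userMessage = new UserMessage("Explain what do you see on this picture?",
            List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData)));
    ChatResponse response = chatClient.call(new Prompt(List.of(userMessage)));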
@@ -180,7 +180,7 @@ void multiModalityTest() throws IOException {
 
 		byte[] imageData = new ClassPathResource("/test.png").getContentAsByteArray();
 
-		var userMessage = new UserMessage("Explain what do you see o this picture?",
+		var userMessage = new UserMessage("Explain what do you see on this picture?",
 				List.of(new Media(MimeTypeUtils.IMAGE_PNG, imageData)));
 
 		ChatResponse response = chatClient.call(new Prompt(List.of(userMessage)));
@@ -15,6 +15,8 @@
  */
 package org.springframework.ai.openai;
 
+import java.util.ArrayList;
+import java.util.Base64;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
@@ -43,6 +45,7 @@
 import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion.Choice;
 import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionFinishReason;
 import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage;
+import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.MediaContent;
 import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.Role;
 import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage.ToolCall;
 import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionRequest;
@@ -53,6 +56,7 @@
 import org.springframework.retry.support.RetryTemplate;
 import org.springframework.util.Assert;
 import org.springframework.util.CollectionUtils;
+import org.springframework.util.MimeType;
 
 /**
  * {@link ChatClient} and {@link StreamingChatClient} implementation for {@literal OpenAI}
@@ -240,11 +244,20 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
 
 		Set<String> functionsForThisRequest = new HashSet<>();
 
-		List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions()
-			.stream()
-			.map(m -> new ChatCompletionMessage(m.getContent(),
-					ChatCompletionMessage.Role.valueOf(m.getMessageType().name())))
-			.toList();
+		List<ChatCompletionMessage> chatCompletionMessages = prompt.getInstructions().stream().map(m -> {
+			// Add text content.
+			List<MediaContent> contents = new ArrayList<>(List.of(new MediaContent(m.getContent())));
+			if (!CollectionUtils.isEmpty(m.getMedia())) {
+				// Add media content.
+				contents.addAll(m.getMedia()
+					.stream()
+					.map(media -> new MediaContent(
+							new MediaContent.ImageUrl(this.fromMediaData(media.getMimeType(), media.getData()))))
+					.toList());
+			}
+
+			return new ChatCompletionMessage(contents, ChatCompletionMessage.Role.valueOf(m.getMessageType().name()));
+		}).toList();
 
 		ChatCompletionRequest request = new ChatCompletionRequest(chatCompletionMessages, stream);
 
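Illustratively, a user message carrying one PNG image is now mapped to a single ChatCompletionMessage whose content is a list of parts rather than a bare String. A sketch using the types added in this commit (the base64 payload is abbreviated):

    List<MediaContent> parts = new ArrayList<>();
    parts.add(new MediaContent("Explain what do you see on this picture?"));                   // "text" part
    parts.add(new MediaContent(new MediaContent.ImageUrl("data:image/png;base64,iVBOR...")));  // "image_url" part
    ChatCompletionMessage mapped = new ChatCompletionMessage(parts, Role.USER);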
@@ -286,6 +299,22 @@ ChatCompletionRequest createRequest(Prompt prompt, boolean stream) {
 		return request;
 	}
 
+	private String fromMediaData(MimeType mimeType, Object mediaContentData) {
+		if (mediaContentData instanceof byte[] bytes) {
+			// Assume the bytes are an image and convert them to a base64 encoded string
+			// following the data URL prefix pattern.
+			return String.format("data:%s;base64,%s", mimeType.toString(), Base64.getEncoder().encodeToString(bytes));
+		}
+		else if (mediaContentData instanceof String text) {
+			// Assume the text is a URL or a base64 encoded image prefixed by the user.
+			return text;
+		}
+		else {
+			throw new IllegalArgumentException(
+					"Unsupported media data type: " + mediaContentData.getClass().getSimpleName());
+		}
+	}
+
 	private List<OpenAiApi.FunctionTool> getFunctionTools(Set<String> functionNames) {
 		return this.resolveFunctionCallbacks(functionNames).stream().map(functionCallback -> {
 			var function = new OpenAiApi.FunctionTool.Function(functionCallback.getDescription(),
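In short, fromMediaData accepts two shapes of media data; a sketch of both (the URL is a hypothetical example):

    // byte[] data is base64 encoded into a data URL:
    new Media(MimeTypeUtils.IMAGE_PNG, imageBytes);   // becomes "data:image/png;base64,iVBOR..."
    // String data is passed through untouched, so a plain URL or a pre-built data URL works:
    new Media(MimeTypeUtils.IMAGE_PNG, "https://example.com/cat.png");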
@@ -407,29 +407,44 @@ public record ResponseFormat(
 	/**
 	 * Message comprising the conversation.
 	 *
-	 * @param content The contents of the message.
+	 * @param rawContent The contents of the message. Can be either a {@link MediaContent} or a {@link String}.
+	 * The response message content is always a {@link String}.
 	 * @param role The role of the messages author. Could be one of the {@link Role} types.
 	 * @param name An optional name for the participant. Provides the model information to differentiate between
-	 * participants of the same role.
+	 * participants of the same role. In case of Function calling, the name is the function name that the message is
+	 * responding to.
 	 * @param toolCallId Tool call that this message is responding to. Only applicable for the {@link Role#TOOL} role
 	 * and null otherwise.
 	 * @param toolCalls The tool calls generated by the model, such as function calls. Applicable only for
 	 * {@link Role#ASSISTANT} role and null otherwise.
 	 */
 	@JsonInclude(Include.NON_NULL)
 	public record ChatCompletionMessage(
-			@JsonProperty("content") String content,
+			@JsonProperty("content") Object rawContent,
 			@JsonProperty("role") Role role,
 			@JsonProperty("name") String name,
 			@JsonProperty("tool_call_id") String toolCallId,
 			@JsonProperty("tool_calls") List<ToolCall> toolCalls) {
 
+		/**
+		 * Get message content as String.
+		 */
+		public String content() {
+			if (this.rawContent == null) {
+				return null;
+			}
+			if (this.rawContent instanceof String text) {
+				return text;
+			}
+			throw new IllegalStateException("The content is not a string!");
+		}
+
 		/**
 		 * Create a chat completion message with the given content and role. All other fields are null.
 		 * @param content The contents of the message.
 		 * @param role The role of the author of this message.
 		 */
-		public ChatCompletionMessage(String content, Role role) {
+		public ChatCompletionMessage(Object content, Role role) {
 			this(content, role, null, null, null);
 		}
 
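Note the accessor's contract: response messages always carry a plain String, so content() is safe on model output, while calling it on a request message built from media parts throws. A sketch (assuming a ChatCompletion named completion):

    String answer = completion.choices().get(0).message().content();    // OK: response content is a String
    new ChatCompletionMessage(List.of(new MediaContent("hi")), Role.USER).content();  // throws IllegalStateException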
@@ -455,6 +470,54 @@ public enum Role {
 			@JsonProperty("tool") TOOL
 		}
 
+		/**
+		 * An array of content parts with a defined type.
+		 * Each MediaContent can be of type "text" or "image_url", not both.
+		 *
+		 * @param type Content type, each can be of type text or image_url.
+		 * @param text The text content of the message.
+		 * @param imageUrl The image content of the message. You can pass multiple
+		 * images by adding multiple image_url content parts. Image input is only
+		 * supported when using the gpt-4-vision-preview model.
+		 */
+		@JsonInclude(Include.NON_NULL)
+		public record MediaContent(
+				@JsonProperty("type") String type,
+				@JsonProperty("text") String text,
+				@JsonProperty("image_url") ImageUrl imageUrl) {
+
+			/**
+			 * @param url Either a URL of the image or the base64 encoded image data.
+			 * The base64 encoded image data must have a special prefix in the following format:
+			 * "data:{mimetype};base64,{base64-encoded-image-data}".
+			 * @param detail Specifies the detail level of the image.
+			 */
+			@JsonInclude(Include.NON_NULL)
+			public record ImageUrl(
+					@JsonProperty("url") String url,
+					@JsonProperty("detail") String detail) {
+
+				public ImageUrl(String url) {
+					this(url, null);
+				}
+			}
+
+			/**
+			 * Shortcut constructor for a text content.
+			 * @param text The text content of the message.
+			 */
+			public MediaContent(String text) {
+				this("text", text, null);
+			}
+
+			/**
+			 * Shortcut constructor for an image content.
+			 * @param imageUrl The image content of the message.
+			 */
+			public MediaContent(ImageUrl imageUrl) {
+				this("image_url", null, imageUrl);
+			}
+		}
 		/**
 		 * The relevant tool call.
 		 *
@@ -483,6 +546,13 @@ public record ChatCompletionFunction(
 		}
 	}
 
+	public static String getTextContent(List<ChatCompletionMessage.MediaContent> content) {
+		return content.stream()
+			.filter(c -> "text".equals(c.type()))
+			.map(ChatCompletionMessage.MediaContent::text)
+			.reduce("", (a, b) -> a + b);
+	}
+
 	/**
 	 * The reason the model stopped generating tokens.
 	 */
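Together with the MediaContent shortcut constructors above, this helper concatenates the text parts of a mixed content list; a sketch (assuming getTextContent is invoked as a static member of OpenAiApi):

    List<ChatCompletionMessage.MediaContent> parts = List.of(
            new MediaContent("Describe this image:"),
            new MediaContent(new MediaContent.ImageUrl("https://example.com/cat.png", "high")));
    String text = OpenAiApi.getTextContent(parts);   // "Describe this image:"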
@@ -21,7 +21,6 @@
 import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
 import reactor.core.publisher.Flux;
 
-import org.springframework.ai.openai.api.OpenAiApi;
 import org.springframework.ai.openai.api.OpenAiApi.ChatCompletion;
 import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionChunk;
 import org.springframework.ai.openai.api.OpenAiApi.ChatCompletionMessage;
@@ -38,16 +38,15 @@ class OpenAiSpeechClientIT extends AbstractIT {
 
 	@Test
 	void shouldSuccessfullyStreamAudioBytesForEmptyMessage() {
-		Flux<byte[]> response = openAiAudioSpeechClient
-			.stream("Today is a wonderful day to build something people love!");
+		Flux<byte[]> response = speechClient.stream("Today is a wonderful day to build something people love!");
 		assertThat(response).isNotNull();
 		assertThat(response.collectList().block()).isNotNull();
 		System.out.println(response.collectList().block());
 	}
 
 	@Test
 	void shouldProduceAudioBytesDirectlyFromMessage() {
-		byte[] audioBytes = openAiAudioSpeechClient.call("Today is a wonderful day to build something people love!");
+		byte[] audioBytes = speechClient.call("Today is a wonderful day to build something people love!");
 		assertThat(audioBytes).hasSizeGreaterThan(0);
 
 	}
@@ -62,7 +61,7 @@ void shouldGenerateNonEmptyMp3AudioFromSpeechPrompt() {
 			.build();
 		SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!",
 				speechOptions);
-		SpeechResponse response = openAiAudioSpeechClient.call(speechPrompt);
+		SpeechResponse response = speechClient.call(speechPrompt);
 		byte[] audioBytes = response.getResult().getOutput();
 		assertThat(response.getResults()).hasSize(1);
 		assertThat(response.getResults().get(0).getOutput()).isNotEmpty();
@@ -80,7 +79,7 @@ void speechRateLimitTest() {
 			.build();
 		SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!",
 				speechOptions);
-		SpeechResponse response = openAiAudioSpeechClient.call(speechPrompt);
+		SpeechResponse response = speechClient.call(speechPrompt);
 		OpenAiAudioSpeechResponseMetadata metadata = response.getMetadata();
 		assertThat(metadata).isNotNull();
 		assertThat(metadata.getRateLimit()).isNotNull();
@@ -101,7 +100,7 @@ void shouldStreamNonEmptyResponsesForValidSpeechPrompts() {
 
 		SpeechPrompt speechPrompt = new SpeechPrompt("Today is a wonderful day to build something people love!",
 				speechOptions);
-		Flux<SpeechResponse> responseFlux = openAiAudioSpeechClient.stream(speechPrompt);
+		Flux<SpeechResponse> responseFlux = speechClient.stream(speechPrompt);
 		assertThat(responseFlux).isNotNull();
 		List<SpeechResponse> responses = responseFlux.collectList().block();
 		assertThat(responses).isNotNull();
@@ -43,7 +43,7 @@ void transcriptionTest() {
 			.withTemperature(0f)
 			.build();
 		AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(audioFile, transcriptionOptions);
-		AudioTranscriptionResponse response = openAiTranscriptionClient.call(transcriptionRequest);
+		AudioTranscriptionResponse response = transcriptionClient.call(transcriptionRequest);
 		assertThat(response.getResults()).hasSize(1);
 		assertThat(response.getResults().get(0).getOutput().toLowerCase().contains("fellow")).isTrue();
 	}
@@ -59,7 +59,7 @@ void transcriptionTestWithOptions() {
 			.withResponseFormat(responseFormat)
 			.build();
 		AudioTranscriptionPrompt transcriptionRequest = new AudioTranscriptionPrompt(audioFile, transcriptionOptions);
-		AudioTranscriptionResponse response = openAiTranscriptionClient.call(transcriptionRequest);
+		AudioTranscriptionResponse response = transcriptionClient.call(transcriptionRequest);
 		assertThat(response.getResults()).hasSize(1);
 		assertThat(response.getResults().get(0).getOutput().toLowerCase().contains("fellow")).isTrue();
 	}