[WIP] Add missing response fields to Spring AI ChatResponse, Generation and AssistantMessage #550

Open · wants to merge 1 commit into base: main
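
At a glance, the diff below adds an id to ChatResponse and an id, a choice index and an isCompleted flag to Generation and AssistantMessage, and wires them up in the Anthropic, Azure OpenAI and Bedrock Anthropic clients. A minimal sketch of how a caller might read the new fields after a blocking call; it uses only accessors exercised by the integration tests in this diff and assumes a chatClient, prompt and logger like the ones in those tests:

// Hedged sketch, not part of the PR: reading the new identity/completion fields.
ChatResponse response = chatClient.call(prompt);

String responseId = response.getId();              // new: response/message id
for (Generation generation : response.getResults()) {
    int index = generation.getIndex();             // new: provider-reported choice index
    boolean completed = generation.isCompleted();  // new: true once a stop/finish reason arrived
    AssistantMessage message = generation.getOutput();
    // The output message mirrors the id, index and completion flag of its Generation.
    logger.info("choice {} of response {} (completed={}): {}", index, responseId, completed,
            message.getContent());
}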
@@ -206,12 +206,26 @@ private ChatResponse toChatResponse(ChatCompletion chatCompletion) {
return new ChatResponse(List.of());
}

List<Generation> generations = chatCompletion.content().stream().map(content -> {
return new Generation(content.text(), Map.of())
.withGenerationMetadata(ChatGenerationMetadata.from(chatCompletion.stopReason(), null));
}).toList();
var id = chatCompletion.id();
int index = 0;

return new ChatResponse(generations, AnthropicChatResponseMetadata.from(chatCompletion));
boolean isCompleted = ("message_stop".equals(chatCompletion.type()) || "message".equals(chatCompletion.type()))
&& (chatCompletion.stopReason() != null);
if (isCompleted) {
logger.info("Chat completion is completed: {}", chatCompletion.stopReason());
}

// "Currently, the only type in responses is 'text'."
// https://docs.anthropic.com/claude/reference/messages_post
String aggregatedText = chatCompletion.content()
.stream()
.map(content -> content.text())
.collect(Collectors.joining());

Generation generation = new Generation(id, index, isCompleted, aggregatedText, Map.of(),
ChatGenerationMetadata.from(chatCompletion.stopReason(), null));

return new ChatResponse(id, List.of(generation), AnthropicChatResponseMetadata.from(chatCompletion));
}
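
For the non-streaming Anthropic path the whole response now collapses into a single Generation (index 0), with completion inferred from the message type and stop reason. The stop reason itself stays reachable through the generation metadata, as the role test further down asserts; a one-line hedged reminder:

// Hedged sketch: the Anthropic stop reason surfaces as the generation's finish reason.
String finishReason = response.getResult().getMetadata().getFinishReason(); // e.g. "end_turn"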

private String fromMediaData(Object mediaData) {
----------------------------------------------------------------
@@ -25,6 +25,7 @@
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.ChatResponse;
@@ -79,10 +80,22 @@ void roleTest() {
assertThat(response.getMetadata().getUsage().getTotalTokens())
.isEqualTo(response.getMetadata().getUsage().getPromptTokens()
+ response.getMetadata().getUsage().getGenerationTokens());

Generation generation = response.getResults().get(0);
assertThat(generation.getOutput().getContent()).contains("Blackbeard");
assertThat(generation.getMetadata().getFinishReason()).isEqualTo("end_turn");
logger.info(response.toString());

assertThat(response.getId()).isNotBlank();

assertThat(generation.getIndex()).isEqualTo(0);
assertThat(generation.isCompleted()).isTrue();

AssistantMessage assistantMessage = generation.getOutput();
assertThat(assistantMessage.getId()).isEqualTo(response.getId());
assertThat(assistantMessage.getIndex()).isEqualTo(generation.getIndex());
assertThat(assistantMessage.isCompleted()).isTrue();

}

@Test
@@ -158,11 +171,11 @@ void beanStreamOutputParserRecords() {
""";
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());
Flux<ChatResponse> response = streamingChatClient.stream(prompt);

List<ChatResponse> chatResponseList = response.collectList().block();

String generationTextFromStream = streamingChatClient.stream(prompt)
.collectList()
.block()
.stream()
String generationTextFromStream = chatResponseList.stream()
.map(ChatResponse::getResults)
.flatMap(List::stream)
.map(Generation::getOutput)
@@ -173,6 +186,41 @@ void beanStreamOutputParserRecords() {
logger.info("" + actorsFilms);
assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
assertThat(actorsFilms.movies()).hasSize(5);

var firstResponse = chatResponseList.get(0);

for (int i = 0; i < chatResponseList.size() - 1; i++) {
var responseX = chatResponseList.get(i);
assertThat(responseX.getId()).isEqualTo(firstResponse.getId());

assertThat(responseX.getResults()).hasSize(1);
var generation = responseX.getResults().get(0);

assertThat(generation.getId()).isEqualTo(firstResponse.getId());
assertThat(generation.getIndex()).isEqualTo(0);
assertThat(generation.isCompleted()).isFalse();

AssistantMessage assistantMessage = generation.getOutput();

assertThat(assistantMessage.getId()).isEqualTo(firstResponse.getId());
assertThat(assistantMessage.getIndex()).isEqualTo(0);
assertThat(assistantMessage.isCompleted()).isFalse();
}

var lastResponse = chatResponseList.get(chatResponseList.size() - 1);
assertThat(lastResponse.getId()).isEqualTo(firstResponse.getId());
assertThat(lastResponse.getResults()).hasSize(1);
var lastGeneration = lastResponse.getResults().get(0);

assertThat(lastGeneration.getId()).isEqualTo(firstResponse.getId());
assertThat(lastGeneration.getIndex()).isEqualTo(0);
assertThat(lastGeneration.isCompleted()).isTrue();

AssistantMessage lastAssistantMessage = lastGeneration.getOutput();

assertThat(lastAssistantMessage.getId()).isEqualTo(firstResponse.getId());
assertThat(lastAssistantMessage.getIndex()).isEqualTo(0);
assertThat(lastAssistantMessage.isCompleted()).isTrue();
}

@Test
----------------------------------------------------------------
@@ -18,6 +18,7 @@
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import com.azure.ai.openai.OpenAIClient;
@@ -37,6 +38,8 @@
import com.azure.ai.openai.models.CompletionsFinishReason;
import com.azure.ai.openai.models.ContentFilterResultsForPrompt;
import com.azure.ai.openai.models.FunctionDefinition;
import com.azure.ai.openai.models.MaxTokensFinishDetails;
import com.azure.ai.openai.models.StopFinishDetails;
import com.azure.core.util.BinaryData;
import com.azure.core.util.IterableStream;
import org.slf4j.Logger;
@@ -137,15 +140,18 @@ public ChatResponse call(Prompt prompt) {
ChatCompletions chatCompletions = this.callWithFunctionSupport(options);
logger.trace("Azure ChatCompletions: {}", chatCompletions);

List<Generation> generations = chatCompletions.getChoices()
.stream()
.map(choice -> new Generation(choice.getMessage().getContent())
.withGenerationMetadata(generateChoiceMetadata(choice)))
.toList();
var id = chatCompletions.getId();

List<Generation> generations = chatCompletions.getChoices().stream().map(choice -> {
boolean isCompleted = choice.getFinishReason() != null || choice.getFinishDetails() != null;
int index = choice.getIndex();
return new Generation(id, index, isCompleted, choice.getMessage().getContent(), Map.of(),
generateChoiceMetadata(choice));
}).toList();

PromptMetadata promptFilterMetadata = generatePromptMetadata(chatCompletions);

return new ChatResponse(generations,
return new ChatResponse(id, generations,
AzureOpenAiChatResponseMetadata.from(chatCompletions, promptFilterMetadata));
}

@@ -162,12 +168,17 @@ public Flux<ChatResponse> stream(Prompt prompt) {
// Note: the first chat completions can be ignored when using Azure OpenAI
// service which is a known service bug.
.skip(1)
.map(ChatCompletions::getChoices)
.flatMap(List::stream)
.map(choice -> {
var content = (choice.getDelta() != null) ? choice.getDelta().getContent() : null;
var generation = new Generation(content).withGenerationMetadata(generateChoiceMetadata(choice));
return new ChatResponse(List.of(generation));
.map(completion -> {
var id = completion.getId();
var generations = completion.getChoices().stream().map(choice -> {
boolean isCompleted = choice.getFinishReason() != null || choice.getFinishDetails() != null;
int index = choice.getIndex();
var content = (choice.getDelta() != null) ? choice.getDelta().getContent() : null;
return new Generation(id, index, isCompleted, content, Map.of(), generateChoiceMetadata(choice));
}).toList();

return new ChatResponse(id, generations,
AzureOpenAiChatResponseMetadata.from(completion, generatePromptMetadata(completion)));
}));
}
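
With this change every streamed chunk becomes a full ChatResponse that carries the completion id, and its generations report isCompleted only once a finish reason or finish details arrive. A minimal hedged sketch of consuming such a stream, using only accessors exercised by the streaming tests in this diff:

// Hedged sketch, not part of the PR: collect the streamed text and detect the terminal chunk.
List<ChatResponse> chunks = chatClient.stream(prompt).collectList().block();

String text = chunks.stream()
        .map(ChatResponse::getResults)
        .flatMap(List::stream)
        .map(generation -> generation.getOutput().getContent())
        .filter(Objects::nonNull)
        .collect(Collectors.joining());

boolean finished = chunks.get(chunks.size() - 1).getResults().get(0).isCompleted();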

----------------------------------------------------------------
@@ -32,6 +32,9 @@ public class AzureOpenAiUsage implements Usage {

public static AzureOpenAiUsage from(ChatCompletions chatCompletions) {
Assert.notNull(chatCompletions, "ChatCompletions must not be null");
if (chatCompletions.getUsage() == null) {
return null;
}
return from(chatCompletions.getUsage());
}
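
Streamed Azure chunks may carry no usage block, so AzureOpenAiUsage.from(ChatCompletions) now returns null when the completion reports no usage; callers have to tolerate that. A short hedged sketch of a defensive caller (the zero-token fallback is an assumption, not part of the PR):

// Hedged sketch: treat a missing usage block as "no token counts reported".
AzureOpenAiUsage usage = AzureOpenAiUsage.from(chatCompletions);
long totalTokens = (usage != null) ? usage.getTotalTokens() : 0L;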

----------------------------------------------------------------
@@ -18,26 +18,28 @@
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;

import com.azure.ai.openai.OpenAIClient;
import com.azure.ai.openai.OpenAIClientBuilder;
import com.azure.core.credential.AzureKeyCredential;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import reactor.core.publisher.Flux;

import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.messages.AssistantMessage;
import org.springframework.ai.parser.BeanOutputParser;
import org.springframework.ai.parser.ListOutputParser;
import org.springframework.ai.parser.MapOutputParser;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.chat.prompt.Prompt;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.chat.prompt.SystemPromptTemplate;
import org.springframework.ai.chat.messages.Message;
import org.springframework.ai.chat.messages.UserMessage;
import org.springframework.ai.parser.BeanOutputParser;
import org.springframework.ai.parser.ListOutputParser;
import org.springframework.ai.parser.MapOutputParser;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
@@ -51,6 +53,8 @@
@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_ENDPOINT", matches = ".+")
class AzureOpenAiChatClientIT {

private static final Logger logger = LoggerFactory.getLogger(AzureOpenAiChatClientIT.class);

@Autowired
private AzureOpenAiChatClient chatClient;

@@ -71,6 +75,17 @@ void roleTest() {
Prompt prompt = new Prompt(List.of(userMessage, systemMessage));
ChatResponse response = chatClient.call(prompt);
assertThat(response.getResult().getOutput().getContent()).contains("Blackbeard");

assertThat(response.getId()).isNotBlank();
var generation = response.getResults().get(0);

assertThat(generation.getIndex()).isEqualTo(0);
assertThat(generation.isCompleted()).isTrue();

AssistantMessage assistantMessage = generation.getOutput();
assertThat(assistantMessage.getId()).isEqualTo(response.getId());
assertThat(assistantMessage.getIndex()).isEqualTo(generation.getIndex());
assertThat(assistantMessage.isCompleted()).isTrue();
}

@Test
@@ -148,7 +163,9 @@ void beanOutputParserRecords() {
Generation generation = chatClient.call(prompt).getResult();

ActorsFilmsRecord actorsFilms = outputParser.parse(generation.getOutput().getContent());
System.out.println(actorsFilms);

logger.info("ActorsFilmsRecord: {}", actorsFilms);

assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
assertThat(actorsFilms.movies()).hasSize(5);
}
@@ -166,21 +183,61 @@ void beanStreamOutputParserRecords() {
PromptTemplate promptTemplate = new PromptTemplate(template, Map.of("format", format));
Prompt prompt = new Prompt(promptTemplate.createMessage());

String generationTextFromStream = chatClient.stream(prompt)
.collectList()
.block()
.stream()
Flux<ChatResponse> response = chatClient.stream(prompt);

List<ChatResponse> chatResponseList = response.collectList().block();

String generationTextFromStream = chatResponseList.stream()
.map(ChatResponse::getResults)
.flatMap(List::stream)
.map(Generation::getOutput)
.map(AssistantMessage::getContent)
.filter(Objects::nonNull)
.collect(Collectors.joining());

ActorsFilmsRecord actorsFilms = outputParser.parse(generationTextFromStream);
System.out.println(actorsFilms);

logger.info("ActorsFilmsRecord: {}", actorsFilms);

assertThat(actorsFilms.actor()).isEqualTo("Tom Hanks");
assertThat(actorsFilms.movies()).hasSize(5);

assertThat(chatResponseList).hasSizeGreaterThan(1);

var firstResponse = chatResponseList.get(0);

for (int i = 0; i < chatResponseList.size() - 1; i++) {
var responseX = chatResponseList.get(i);
assertThat(responseX.getId()).isEqualTo(firstResponse.getId());

assertThat(responseX.getResults()).hasSize(1);
var generation = responseX.getResults().get(0);

assertThat(generation.getId()).isEqualTo(firstResponse.getId());
assertThat(generation.getIndex()).isEqualTo(0);
assertThat(generation.isCompleted()).isFalse();

AssistantMessage assistantMessage = generation.getOutput();

assertThat(assistantMessage.getId()).isEqualTo(firstResponse.getId());
assertThat(assistantMessage.getIndex()).isEqualTo(0);
assertThat(assistantMessage.isCompleted()).isFalse();
}

var lastResponse = chatResponseList.get(chatResponseList.size() - 1);
assertThat(lastResponse.getId()).isEqualTo(firstResponse.getId());
assertThat(lastResponse.getResults()).hasSize(1);
var lastGeneration = lastResponse.getResults().get(0);

assertThat(lastGeneration.getId()).isEqualTo(firstResponse.getId());
assertThat(lastGeneration.getIndex()).isEqualTo(0);
assertThat(lastGeneration.isCompleted()).isTrue();

AssistantMessage lastAssistantMessage = lastGeneration.getOutput();

assertThat(lastAssistantMessage.getId()).isEqualTo(firstResponse.getId());
assertThat(lastAssistantMessage.getIndex()).isEqualTo(0);
assertThat(lastAssistantMessage.isCompleted()).isTrue();

}

@SpringBootConfiguration
----------------------------------------------------------------
@@ -16,11 +16,16 @@
package org.springframework.ai.bedrock.anthropic;

import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.atomic.AtomicReference;

import org.springframework.ai.chat.ChatClient;
import org.springframework.ai.chat.prompt.ChatOptions;
import org.springframework.ai.chat.ChatResponse;
import org.springframework.ai.chat.metadata.ChatGenerationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;

import reactor.core.publisher.Flux;

import org.springframework.ai.bedrock.MessageToPromptConverter;
@@ -67,7 +72,13 @@ public ChatResponse call(Prompt prompt) {

AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(request);

return new ChatResponse(List.of(new Generation(response.completion())));
var id = UUID.randomUUID().toString();
var text = response.completion();
var isCompleted = response.stopReason() != null;
var metadata = ChatGenerationMetadata.from(response.stopReason(), response.amazonBedrockInvocationMetrics());

return new ChatResponse(id, List.of(new Generation(id, 0, isCompleted, text, Map.of(), metadata)),
ChatResponseMetadata.NULL);
}

@Override
@@ -77,14 +88,23 @@ public Flux<ChatResponse> stream(Prompt prompt) {

Flux<AnthropicChatResponse> fluxResponse = this.anthropicChatApi.chatCompletionStream(request);

AtomicReference<String> idRef = new AtomicReference<>(UUID.randomUUID().toString());

return fluxResponse.map(response -> {

String id = idRef.get();
String stopReason = response.stopReason() != null ? response.stopReason() : null;
var generation = new Generation(response.completion());
if (response.amazonBedrockInvocationMetrics() != null) {
generation = generation.withGenerationMetadata(
ChatGenerationMetadata.from(stopReason, response.amazonBedrockInvocationMetrics()));
var text = response.completion();
var isCompleted = response.stopReason() != null;
var metadata = ChatGenerationMetadata.from(stopReason, response.amazonBedrockInvocationMetrics());

var generation = new Generation(id, 0, isCompleted, text, Map.of(), metadata);

if (isCompleted) {
idRef.set(UUID.randomUUID().toString());
}
return new ChatResponse(List.of(generation));

return new ChatResponse(id, List.of(generation), ChatResponseMetadata.NULL);
});
}
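
Bedrock's Anthropic text API does not return a message id, so the client synthesizes one: the UUID held in the AtomicReference is reused for every chunk of one logical answer and rotated only after a chunk reports isCompleted. A hedged sketch of what a stream consumer could then check (the single-answer expectation is an assumption):

// Hedged sketch, not part of the PR: all chunks of one answer share the synthetic id.
List<ChatResponse> chunks = chatClient.stream(prompt).collectList().block();

long distinctIds = chunks.stream().map(ChatResponse::getId).distinct().count();
// Expected to be 1 for a single prompt/answer exchange, since the id only rotates
// after the chunk that reports isCompleted.
boolean lastChunkCompleted = chunks.get(chunks.size() - 1).getResults().get(0).isCompleted();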
