Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bedrock AI models add usage information. #605

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package org.springframework.ai.bedrock;

import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationMetadata;
import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationMetrics;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.util.Assert;
Expand All @@ -23,38 +24,45 @@
* {@link Usage} implementation for Bedrock API.
*
* @author Christian Tzolov
* @author Wei Jiang
* @since 0.8.0
*/
public class BedrockUsage implements Usage {

public static BedrockUsage from(AmazonBedrockInvocationMetrics usage) {
return new BedrockUsage(usage);
Assert.notNull(usage, "Amazon Bedrock Usage must not be null");

return new BedrockUsage(usage.inputTokenCount().longValue(), usage.outputTokenCount().longValue());
}

private final AmazonBedrockInvocationMetrics usage;
public static BedrockUsage from(AmazonBedrockInvocationMetadata metadata) {
Assert.notNull(metadata, "Amazon Bedrock Invocation Metadata must not be null");

protected BedrockUsage(AmazonBedrockInvocationMetrics usage) {
Assert.notNull(usage, "OpenAI Usage must not be null");
this.usage = usage;
return new BedrockUsage(metadata.inputTokenCount(), metadata.outputTokenCount());
}

protected AmazonBedrockInvocationMetrics getUsage() {
return this.usage;
private Long promptTokens;

private Long generationTokens;

protected BedrockUsage(Long promptTokens, Long generationTokens) {
this.promptTokens = promptTokens;
this.generationTokens = generationTokens;
}

@Override
public Long getPromptTokens() {
return getUsage().inputTokenCount().longValue();
return this.promptTokens;
}

@Override
public Long getGenerationTokens() {
return getUsage().outputTokenCount().longValue();
return this.generationTokens;
}

@Override
public String toString() {
return getUsage().toString();
return "BedrockUsage [promptTokens=" + promptTokens + ", generationTokens=" + generationTokens + "]";
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi;
import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatRequest;
import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatResponse;
import org.springframework.ai.bedrock.anthropic.metadata.BedrockAnthropicChatResponseMetadata;
import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationContext;
import org.springframework.ai.chat.StreamingChatClient;
import org.springframework.ai.chat.Generation;
import org.springframework.ai.chat.prompt.Prompt;
Expand All @@ -37,6 +39,7 @@
* generative.
*
* @author Christian Tzolov
* @author Wei Jiang
* @since 0.8.0
*/
public class BedrockAnthropicChatClient implements ChatClient, StreamingChatClient {
Expand Down Expand Up @@ -65,9 +68,16 @@ public ChatResponse call(Prompt prompt) {

AnthropicChatRequest request = createRequest(prompt);

AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(request);
AmazonBedrockInvocationContext<AnthropicChatResponse> context = anthropicChatApi.chatCompletion(request);

return new ChatResponse(List.of(new Generation(response.completion())));
AnthropicChatResponse response = context.response();

List<Generation> generations = List.of(new Generation(response.completion()));

BedrockAnthropicChatResponseMetadata chatResponseMetadata = BedrockAnthropicChatResponseMetadata.from(response,
context.metadata());

return new ChatResponse(generations, chatResponseMetadata);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ public String id() {
}

@Override
public AnthropicChatResponse chatCompletion(AnthropicChatRequest anthropicRequest) {
public AmazonBedrockInvocationContext<AnthropicChatResponse> chatCompletion(AnthropicChatRequest anthropicRequest) {
Assert.notNull(anthropicRequest, "'anthropicRequest' must not be null");
return this.internalInvocation(anthropicRequest, AnthropicChatResponse.class);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.bedrock.anthropic.metadata;

import org.springframework.ai.bedrock.BedrockUsage;
import org.springframework.ai.bedrock.anthropic.api.AnthropicChatBedrockApi.AnthropicChatResponse;
import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.util.Assert;

/**
 * {@link ChatResponseMetadata} implementation for
 * {@literal Amazon Bedrock Anthropic Chat Model}.
 *
 * <p>
 * Carries the AWS request id, the invocation latency and the token {@link Usage}
 * reported by the Bedrock invocation metadata. All fields are assigned once in the
 * constructor and never modified afterwards.
 *
 * @author Wei Jiang
 * @see ChatResponseMetadata
 * @since 0.8.1
 */
public class BedrockAnthropicChatResponseMetadata implements ChatResponseMetadata {

    /**
     * Template used by {@link #toString()}; positional arguments are, in order: type
     * name, id, latency, usage and rate limit.
     */
    protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, latency: %3$s, usage: %4$s, rateLimit: %5$s }";

    /**
     * Creates the chat response metadata from a Bedrock invocation.
     * @param response the Anthropic chat response; currently not read by this factory
     * (the id is taken from the invocation metadata instead)
     * @param invocationMetadata the Bedrock invocation metadata supplying the AWS request
     * id, the invocation latency and the token counts; must not be {@code null}
     * @return a new {@link BedrockAnthropicChatResponseMetadata}
     * @throws IllegalArgumentException if {@code invocationMetadata} is {@code null}
     */
    public static BedrockAnthropicChatResponseMetadata from(AnthropicChatResponse response,
            AmazonBedrockInvocationMetadata invocationMetadata) {
        Assert.notNull(invocationMetadata, "Bedrock invocation metadata must not be null");

        BedrockUsage usage = BedrockUsage.from(invocationMetadata);

        return new BedrockAnthropicChatResponseMetadata(invocationMetadata.awsRequestId(),
                invocationMetadata.invocationLatency(), usage);
    }

    private final String id;

    // Assigned only in the constructor, so it is final like the other fields.
    private final Long invocationLatency;

    private final Usage usage;

    /**
     * Creates a new metadata instance.
     * @param id the identifier of the invocation (the AWS request id when built through
     * {@link #from})
     * @param invocationLatency the latency reported for the invocation (unit not visible
     * here; presumably milliseconds - TODO confirm against the Bedrock API)
     * @param usage the token usage of the invocation
     */
    protected BedrockAnthropicChatResponseMetadata(String id, Long invocationLatency, BedrockUsage usage) {
        this.id = id;
        this.invocationLatency = invocationLatency;
        this.usage = usage;
    }

    /**
     * Returns the identifier of the invocation.
     * @return the id passed at construction time
     */
    public String getId() {
        return this.id;
    }

    /**
     * Returns the latency reported for the invocation.
     * @return the invocation latency passed at construction time
     */
    public Long getInvocationLatency() {
        return this.invocationLatency;
    }

    @Override
    public Usage getUsage() {
        return this.usage;
    }

    /**
     * Returns an empty {@link PromptMetadata}; no prompt metadata is tracked by this
     * implementation.
     */
    @Override
    public PromptMetadata getPromptMetadata() {
        return PromptMetadata.empty();
    }

    @Override
    public String toString() {
        return AI_METADATA_STRING.formatted(getClass().getTypeName(), getId(), getInvocationLatency(), getUsage(),
                getRateLimit());
    }

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatResponse;
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatStreamingResponse.StreamingType;
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.MediaContent;
import org.springframework.ai.bedrock.anthropic3.metadata.BedrockAnthropic3ChatResponseMetadata;
import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationContext;
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage;
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.ChatCompletionMessage.Role;
import org.springframework.ai.chat.ChatClient;
Expand Down Expand Up @@ -48,6 +50,7 @@
*
* @author Ben Middleton
* @author Christian Tzolov
* @author Wei Jiang
* @since 1.0.0
*/
public class BedrockAnthropic3ChatClient implements ChatClient, StreamingChatClient {
Expand Down Expand Up @@ -76,9 +79,16 @@ public ChatResponse call(Prompt prompt) {

AnthropicChatRequest request = createRequest(prompt);

AnthropicChatResponse response = this.anthropicChatApi.chatCompletion(request);
AmazonBedrockInvocationContext<AnthropicChatResponse> context = this.anthropicChatApi.chatCompletion(request);

return new ChatResponse(List.of(new Generation(response.content().get(0).text())));
AnthropicChatResponse response = context.response();

List<Generation> generations = List.of(new Generation(response.content().get(0).text()));

BedrockAnthropic3ChatResponseMetadata chatResponseMetadata = BedrockAnthropic3ChatResponseMetadata
.from(response, context.metadata());

return new ChatResponse(generations, chatResponseMetadata);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ public String id() {
}

@Override
public AnthropicChatResponse chatCompletion(AnthropicChatRequest anthropicRequest) {
public AmazonBedrockInvocationContext<AnthropicChatResponse> chatCompletion(AnthropicChatRequest anthropicRequest) {
Assert.notNull(anthropicRequest, "'anthropicRequest' must not be null");
return this.internalInvocation(anthropicRequest, AnthropicChatResponse.class);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*
* Copyright 2023 - 2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.bedrock.anthropic3.metadata;

import org.springframework.ai.bedrock.BedrockUsage;
import org.springframework.ai.bedrock.anthropic3.api.Anthropic3ChatBedrockApi.AnthropicChatResponse;
import org.springframework.ai.bedrock.api.AbstractBedrockApi.AmazonBedrockInvocationMetadata;
import org.springframework.ai.chat.metadata.ChatResponseMetadata;
import org.springframework.ai.chat.metadata.PromptMetadata;
import org.springframework.ai.chat.metadata.Usage;
import org.springframework.util.Assert;

/**
 * {@link ChatResponseMetadata} implementation for
 * {@literal Amazon Bedrock Anthropic 3 Chat Model}.
 *
 * <p>
 * Carries the response id, the invocation latency and the token {@link Usage} of a
 * Bedrock Anthropic 3 invocation. All fields are assigned once in the constructor and
 * never modified afterwards.
 *
 * @author Wei Jiang
 * @see ChatResponseMetadata
 * @since 0.8.1
 */
public class BedrockAnthropic3ChatResponseMetadata implements ChatResponseMetadata {

    /**
     * Template used by {@link #toString()}; positional arguments are, in order: type
     * name, id, latency, usage and rate limit.
     */
    protected static final String AI_METADATA_STRING = "{ @type: %1$s, id: %2$s, latency: %3$s, usage: %4$s, rateLimit: %5$s }";

    /**
     * Creates the chat response metadata from a Bedrock invocation.
     * @param response the Anthropic 3 chat response supplying the id; must not be
     * {@code null}
     * @param invocationMetadata the Bedrock invocation metadata supplying the invocation
     * latency and the token counts; must not be {@code null}
     * @return a new {@link BedrockAnthropic3ChatResponseMetadata}
     * @throws IllegalArgumentException if {@code response} or {@code invocationMetadata}
     * is {@code null}
     */
    public static BedrockAnthropic3ChatResponseMetadata from(AnthropicChatResponse response,
            AmazonBedrockInvocationMetadata invocationMetadata) {
        Assert.notNull(invocationMetadata, "Bedrock invocation metadata must not be null");
        // response.id() is dereferenced below; fail with a descriptive message instead
        // of a bare NullPointerException.
        Assert.notNull(response, "Anthropic chat response must not be null");

        BedrockUsage usage = BedrockUsage.from(invocationMetadata);

        return new BedrockAnthropic3ChatResponseMetadata(response.id(), invocationMetadata.invocationLatency(),
                usage);
    }

    private final String id;

    // Assigned only in the constructor, so it is final like the other fields.
    private final Long invocationLatency;

    private final Usage usage;

    /**
     * Creates a new metadata instance.
     * @param id the identifier of the response (the response id when built through
     * {@link #from})
     * @param invocationLatency the latency reported for the invocation (unit not visible
     * here; presumably milliseconds - TODO confirm against the Bedrock API)
     * @param usage the token usage of the invocation
     */
    protected BedrockAnthropic3ChatResponseMetadata(String id, Long invocationLatency, BedrockUsage usage) {
        this.id = id;
        this.invocationLatency = invocationLatency;
        this.usage = usage;
    }

    /**
     * Returns the identifier of the response.
     * @return the id passed at construction time
     */
    public String getId() {
        return this.id;
    }

    /**
     * Returns the latency reported for the invocation.
     * @return the invocation latency passed at construction time
     */
    public Long getInvocationLatency() {
        return this.invocationLatency;
    }

    @Override
    public Usage getUsage() {
        return this.usage;
    }

    /**
     * Returns an empty {@link PromptMetadata}; no prompt metadata is tracked by this
     * implementation.
     */
    @Override
    public PromptMetadata getPromptMetadata() {
        return PromptMetadata.empty();
    }

    @Override
    public String toString() {
        return AI_METADATA_STRING.formatted(getClass().getTypeName(), getId(), getInvocationLatency(), getUsage(),
                getRateLimit());
    }

}
Loading