Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,6 @@
import org.springframework.ai.support.UsageCalculator;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.ai.util.json.JsonParser;
import org.springframework.core.retry.RetryException;
import org.springframework.core.retry.RetryTemplate;
import org.springframework.http.HttpHeaders;
import org.springframework.http.ResponseEntity;
Expand Down Expand Up @@ -194,19 +193,8 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
this.observationRegistry)
.observe(() -> {

ResponseEntity<ChatCompletionResponse> completionEntity = null;
try {
completionEntity = this.retryTemplate.execute(() -> this.anthropicApi.chatCompletionEntity(request,
this.getAdditionalHttpHeaders(prompt)));
}
catch (RetryException e) {
if (e.getCause() instanceof RuntimeException r) {
throw r;
}
else {
throw new RuntimeException(e.getCause());
}
}
ResponseEntity<ChatCompletionResponse> completionEntity = RetryUtils.execute(this.retryTemplate,
() -> this.anthropicApi.chatCompletionEntity(request, this.getAdditionalHttpHeaders(prompt)));

AnthropicApi.ChatCompletionResponse completionResponse = completionEntity.getBody();
AnthropicApi.Usage usage = completionResponse.usage();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.support.UsageCalculator;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.core.retry.RetryException;
import org.springframework.core.retry.RetryTemplate;
import org.springframework.http.ResponseEntity;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -166,18 +165,8 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
this.observationRegistry)
.observe(() -> {

ResponseEntity<ChatCompletion> completionEntity = null;
try {
completionEntity = this.retryTemplate.execute(() -> this.deepSeekApi.chatCompletionEntity(request));
}
catch (RetryException e) {
if (e.getCause() instanceof RuntimeException r) {
throw r;
}
else {
throw new RuntimeException(e.getCause());
}
}
ResponseEntity<ChatCompletion> completionEntity = RetryUtils.execute(this.retryTemplate,
() -> this.deepSeekApi.chatCompletionEntity(request));

var chatCompletion = completionEntity.getBody();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
import org.springframework.ai.audio.tts.TextToSpeechResponse;
import org.springframework.ai.elevenlabs.api.ElevenLabsApi;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.core.retry.RetryException;
import org.springframework.core.retry.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
Expand Down Expand Up @@ -72,26 +71,15 @@ public static Builder builder() {
public TextToSpeechResponse call(TextToSpeechPrompt prompt) {
RequestContext requestContext = prepareRequest(prompt);

byte[] audioData = null;
try {
audioData = this.retryTemplate.execute(() -> {
var response = this.elevenLabsApi.textToSpeech(requestContext.request, requestContext.voiceId,
requestContext.queryParameters);
if (response.getBody() == null) {
logger.warn("No speech response returned for request: {}", requestContext.request);
return new byte[0];
}
return response.getBody();
});
}
catch (RetryException e) {
if (e.getCause() instanceof RuntimeException r) {
throw r;
}
else {
throw new RuntimeException(e.getCause());
byte[] audioData = RetryUtils.execute(this.retryTemplate, () -> {
var response = this.elevenLabsApi.textToSpeech(requestContext.request, requestContext.voiceId,
requestContext.queryParameters);
if (response.getBody() == null) {
logger.warn("No speech response returned for request: {}", requestContext.request);
return new byte[0];
}
}
return response.getBody();
});

return new TextToSpeechResponse(List.of(new Speech(audioData)));
}
Expand All @@ -100,19 +88,10 @@ public TextToSpeechResponse call(TextToSpeechPrompt prompt) {
public Flux<TextToSpeechResponse> stream(TextToSpeechPrompt prompt) {
RequestContext requestContext = prepareRequest(prompt);

try {
return this.retryTemplate.execute(() -> this.elevenLabsApi
.textToSpeechStream(requestContext.request, requestContext.voiceId, requestContext.queryParameters)
.map(entity -> new TextToSpeechResponse(List.of(new Speech(entity.getBody())))));
}
catch (RetryException e) {
if (e.getCause() instanceof RuntimeException r) {
throw r;
}
else {
throw new RuntimeException(e.getCause());
}
}
return RetryUtils.execute(this.retryTemplate,
() -> this.elevenLabsApi
.textToSpeechStream(requestContext.request, requestContext.voiceId, requestContext.queryParameters)
.map(entity -> new TextToSpeechResponse(List.of(new Speech(entity.getBody())))));
}

private RequestContext prepareRequest(TextToSpeechPrompt prompt) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.observation.conventions.AiProvider;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.core.retry.RetryException;
import org.springframework.core.retry.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
Expand Down Expand Up @@ -169,19 +168,8 @@ public EmbeddingResponse call(EmbeddingRequest request) {
}

// Call the embedding API with retry
EmbedContentResponse embeddingResponse = null;
try {
embeddingResponse = this.retryTemplate
.execute(() -> this.genAiClient.models.embedContent(modelName, validTexts, config));
}
catch (RetryException e) {
if (e.getCause() instanceof RuntimeException r) {
throw r;
}
else {
throw new RuntimeException(e.getCause());
}
}
EmbedContentResponse embeddingResponse = RetryUtils.execute(this.retryTemplate,
() -> this.genAiClient.models.embedContent(modelName, validTexts, config));

// Process the response
// Note: We need to handle the case where some texts were filtered out
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@
import org.springframework.ai.support.UsageCalculator;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.beans.factory.DisposableBean;
import org.springframework.core.retry.RetryException;
import org.springframework.core.retry.RetryTemplate;
import org.springframework.lang.NonNull;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -406,39 +405,31 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationRegistry)
.observe(() -> {
try {
return this.retryTemplate.execute(() -> {

var geminiRequest = createGeminiRequest(prompt);

GenerateContentResponse generateContentResponse = this.getContentResponse(geminiRequest);

List<Generation> generations = generateContentResponse.candidates()
.orElse(List.of())
.stream()
.map(this::responseCandidateToGeneration)
.flatMap(List::stream)
.toList();

var usage = generateContentResponse.usageMetadata();
GoogleGenAiChatOptions options = (GoogleGenAiChatOptions) prompt.getOptions();
Usage currentUsage = (usage.isPresent()) ? getDefaultUsage(usage.get(), options)
: getDefaultUsage(null, options);
Usage cumulativeUsage = UsageCalculator.getCumulativeUsage(currentUsage, previousChatResponse);
ChatResponse chatResponse = new ChatResponse(generations,
toChatResponseMetadata(cumulativeUsage, generateContentResponse.modelVersion().get()));

observationContext.setResponse(chatResponse);
return chatResponse;
});
}
catch (RetryException e) {
if (e.getCause() instanceof RuntimeException r) {
throw r;
}

throw new RuntimeException(e);
}
return RetryUtils.execute(this.retryTemplate, () -> {

var geminiRequest = createGeminiRequest(prompt);

GenerateContentResponse generateContentResponse = this.getContentResponse(geminiRequest);

List<Generation> generations = generateContentResponse.candidates()
.orElse(List.of())
.stream()
.map(this::responseCandidateToGeneration)
.flatMap(List::stream)
.toList();

var usage = generateContentResponse.usageMetadata();
GoogleGenAiChatOptions options = (GoogleGenAiChatOptions) prompt.getOptions();
Usage currentUsage = (usage.isPresent()) ? getDefaultUsage(usage.get(), options)
: getDefaultUsage(null, options);
Usage cumulativeUsage = UsageCalculator.getCumulativeUsage(currentUsage, previousChatResponse);
ChatResponse chatResponse = new ChatResponse(generations,
toChatResponseMetadata(cumulativeUsage, generateContentResponse.modelVersion().get()));

observationContext.setResponse(chatResponse);
return chatResponse;
});
});

if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@
import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.core.retry.RetryException;
import org.springframework.core.retry.RetryTemplate;
import org.springframework.http.ResponseEntity;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -254,18 +253,8 @@ public ChatResponse call(Prompt prompt) {
this.observationRegistry)
.observe(() -> {

ResponseEntity<ChatCompletion> completionEntity = null;
try {
completionEntity = this.retryTemplate.execute(() -> this.miniMaxApi.chatCompletionEntity(request));
}
catch (RetryException e) {
if (e.getCause() instanceof RuntimeException r) {
throw r;
}
else {
throw new RuntimeException(e.getCause());
}
}
ResponseEntity<ChatCompletion> completionEntity = RetryUtils.execute(this.retryTemplate,
() -> this.miniMaxApi.chatCompletionEntity(request));

var chatCompletion = completionEntity.getBody();

Expand Down Expand Up @@ -339,18 +328,8 @@ public Flux<ChatResponse> stream(Prompt prompt) {
return Flux.deferContextual(contextView -> {
ChatCompletionRequest request = createRequest(requestPrompt, true);

Flux<ChatCompletionChunk> completionChunks = null;
try {
completionChunks = this.retryTemplate.execute(() -> this.miniMaxApi.chatCompletionStream(request));
}
catch (RetryException e) {
if (e.getCause() instanceof RuntimeException r) {
throw r;
}
else {
throw new RuntimeException(e.getCause());
}
}
Flux<ChatCompletionChunk> completionChunks = RetryUtils.execute(this.retryTemplate,
() -> this.miniMaxApi.chatCompletionStream(request));

// For chunked responses, only the first chunk contains the choice role.
// The rest of the chunks with same ID share the same role.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
import org.springframework.ai.minimax.api.MiniMaxApiConstants;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.core.retry.RetryException;
import org.springframework.core.retry.RetryTemplate;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
Expand Down Expand Up @@ -166,19 +165,8 @@ public EmbeddingResponse call(EmbeddingRequest request) {
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationRegistry)
.observe(() -> {
MiniMaxApi.EmbeddingList apiEmbeddingResponse = null;
try {
apiEmbeddingResponse = this.retryTemplate
.execute(() -> this.miniMaxApi.embeddings(apiRequest).getBody());
}
catch (RetryException e) {
if (e.getCause() instanceof RuntimeException r) {
throw r;
}
else {
throw new RuntimeException(e.getCause());
}
}
MiniMaxApi.EmbeddingList apiEmbeddingResponse = RetryUtils.execute(this.retryTemplate,
() -> this.miniMaxApi.embeddings(apiRequest).getBody());

if (apiEmbeddingResponse == null) {
logger.warn("No embeddings returned for request: {}", request);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@
import org.springframework.ai.retry.RetryUtils;
import org.springframework.ai.support.UsageCalculator;
import org.springframework.ai.tool.definition.ToolDefinition;
import org.springframework.core.retry.RetryException;
import org.springframework.core.retry.RetryTemplate;
import org.springframework.http.ResponseEntity;
import org.springframework.util.Assert;
Expand Down Expand Up @@ -192,19 +191,8 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons
this.observationRegistry)
.observe(() -> {

ResponseEntity<ChatCompletion> completionEntity = null;
try {
completionEntity = this.retryTemplate
.execute(() -> this.mistralAiApi.chatCompletionEntity(request));
}
catch (RetryException e) {
if (e.getCause() instanceof RuntimeException r) {
throw r;
}
else {
throw new RuntimeException(e.getCause());
}
}
ResponseEntity<ChatCompletion> completionEntity = RetryUtils.execute(this.retryTemplate,
() -> this.mistralAiApi.chatCompletionEntity(request));

ChatCompletion chatCompletion = completionEntity.getBody();

Expand Down Expand Up @@ -276,18 +264,8 @@ public Flux<ChatResponse> internalStream(Prompt prompt, ChatResponse previousCha

observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start();

Flux<ChatCompletionChunk> completionChunks = null;
try {
completionChunks = this.retryTemplate.execute(() -> this.mistralAiApi.chatCompletionStream(request));
}
catch (RetryException e) {
if (e.getCause() instanceof RuntimeException r) {
throw r;
}
else {
throw new RuntimeException(e.getCause());
}
}
Flux<ChatCompletionChunk> completionChunks = RetryUtils.execute(this.retryTemplate,
() -> this.mistralAiApi.chatCompletionStream(request));

// For chunked responses, only the first chunk contains the choice role.
// The rest of the chunks with same ID share the same role.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
import org.springframework.ai.mistralai.api.MistralAiApi;
import org.springframework.ai.model.ModelOptionsUtils;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.core.retry.RetryException;
import org.springframework.core.retry.RetryTemplate;
import org.springframework.util.Assert;

Expand Down Expand Up @@ -118,19 +117,8 @@ public EmbeddingResponse call(EmbeddingRequest request) {
.observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext,
this.observationRegistry)
.observe(() -> {
MistralAiApi.EmbeddingList<MistralAiApi.Embedding> apiEmbeddingResponse = null;
try {
apiEmbeddingResponse = this.retryTemplate
.execute(() -> this.mistralAiApi.embeddings(apiRequest).getBody());
}
catch (RetryException e) {
if (e.getCause() instanceof RuntimeException r) {
throw r;
}
else {
throw new RuntimeException(e.getCause());
}
}
MistralAiApi.EmbeddingList<MistralAiApi.Embedding> apiEmbeddingResponse = RetryUtils
.execute(this.retryTemplate, () -> this.mistralAiApi.embeddings(apiRequest).getBody());

if (apiEmbeddingResponse == null) {
logger.warn("No embeddings returned for request: {}", request);
Expand Down
Loading