diff --git a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java index 1a7ea20f5fe..6cd8987e446 100644 --- a/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java +++ b/models/spring-ai-anthropic/src/main/java/org/springframework/ai/anthropic/AnthropicChatModel.java @@ -78,7 +78,6 @@ import org.springframework.ai.support.UsageCalculator; import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.ai.util.json.JsonParser; -import org.springframework.core.retry.RetryException; import org.springframework.core.retry.RetryTemplate; import org.springframework.http.HttpHeaders; import org.springframework.http.ResponseEntity; @@ -194,19 +193,8 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons this.observationRegistry) .observe(() -> { - ResponseEntity completionEntity = null; - try { - completionEntity = this.retryTemplate.execute(() -> this.anthropicApi.chatCompletionEntity(request, - this.getAdditionalHttpHeaders(prompt))); - } - catch (RetryException e) { - if (e.getCause() instanceof RuntimeException r) { - throw r; - } - else { - throw new RuntimeException(e.getCause()); - } - } + ResponseEntity completionEntity = RetryUtils.execute(this.retryTemplate, + () -> this.anthropicApi.chatCompletionEntity(request, this.getAdditionalHttpHeaders(prompt))); AnthropicApi.ChatCompletionResponse completionResponse = completionEntity.getBody(); AnthropicApi.Usage usage = completionResponse.usage(); diff --git a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java index c89a1c0c812..1ddfc926001 100644 --- a/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java +++ b/models/spring-ai-deepseek/src/main/java/org/springframework/ai/deepseek/DeepSeekChatModel.java @@ -66,7 +66,6 @@ import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.support.UsageCalculator; import org.springframework.ai.tool.definition.ToolDefinition; -import org.springframework.core.retry.RetryException; import org.springframework.core.retry.RetryTemplate; import org.springframework.http.ResponseEntity; import org.springframework.util.Assert; @@ -166,18 +165,8 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons this.observationRegistry) .observe(() -> { - ResponseEntity completionEntity = null; - try { - completionEntity = this.retryTemplate.execute(() -> this.deepSeekApi.chatCompletionEntity(request)); - } - catch (RetryException e) { - if (e.getCause() instanceof RuntimeException r) { - throw r; - } - else { - throw new RuntimeException(e.getCause()); - } - } + ResponseEntity completionEntity = RetryUtils.execute(this.retryTemplate, + () -> this.deepSeekApi.chatCompletionEntity(request)); var chatCompletion = completionEntity.getBody(); diff --git a/models/spring-ai-elevenlabs/src/main/java/org/springframework/ai/elevenlabs/ElevenLabsTextToSpeechModel.java b/models/spring-ai-elevenlabs/src/main/java/org/springframework/ai/elevenlabs/ElevenLabsTextToSpeechModel.java index 187370a91c0..9db0ac6d83e 100644 --- a/models/spring-ai-elevenlabs/src/main/java/org/springframework/ai/elevenlabs/ElevenLabsTextToSpeechModel.java +++ b/models/spring-ai-elevenlabs/src/main/java/org/springframework/ai/elevenlabs/ElevenLabsTextToSpeechModel.java @@ -28,7 +28,6 @@ import org.springframework.ai.audio.tts.TextToSpeechResponse; import org.springframework.ai.elevenlabs.api.ElevenLabsApi; import org.springframework.ai.retry.RetryUtils; -import org.springframework.core.retry.RetryException; import org.springframework.core.retry.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; @@ -72,26 +71,15 @@ public static Builder builder() { public TextToSpeechResponse call(TextToSpeechPrompt prompt) { RequestContext requestContext = prepareRequest(prompt); - byte[] audioData = null; - try { - audioData = this.retryTemplate.execute(() -> { - var response = this.elevenLabsApi.textToSpeech(requestContext.request, requestContext.voiceId, - requestContext.queryParameters); - if (response.getBody() == null) { - logger.warn("No speech response returned for request: {}", requestContext.request); - return new byte[0]; - } - return response.getBody(); - }); - } - catch (RetryException e) { - if (e.getCause() instanceof RuntimeException r) { - throw r; - } - else { - throw new RuntimeException(e.getCause()); + byte[] audioData = RetryUtils.execute(this.retryTemplate, () -> { + var response = this.elevenLabsApi.textToSpeech(requestContext.request, requestContext.voiceId, + requestContext.queryParameters); + if (response.getBody() == null) { + logger.warn("No speech response returned for request: {}", requestContext.request); + return new byte[0]; } - } + return response.getBody(); + }); return new TextToSpeechResponse(List.of(new Speech(audioData))); } @@ -100,19 +88,10 @@ public TextToSpeechResponse call(TextToSpeechPrompt prompt) { public Flux stream(TextToSpeechPrompt prompt) { RequestContext requestContext = prepareRequest(prompt); - try { - return this.retryTemplate.execute(() -> this.elevenLabsApi - .textToSpeechStream(requestContext.request, requestContext.voiceId, requestContext.queryParameters) - .map(entity -> new TextToSpeechResponse(List.of(new Speech(entity.getBody()))))); - } - catch (RetryException e) { - if (e.getCause() instanceof RuntimeException r) { - throw r; - } - else { - throw new RuntimeException(e.getCause()); - } - } + return RetryUtils.execute(this.retryTemplate, + () -> this.elevenLabsApi + .textToSpeechStream(requestContext.request, requestContext.voiceId, requestContext.queryParameters) + .map(entity -> new TextToSpeechResponse(List.of(new Speech(entity.getBody()))))); } private RequestContext prepareRequest(TextToSpeechPrompt prompt) { diff --git a/models/spring-ai-google-genai-embedding/src/main/java/org/springframework/ai/google/genai/text/GoogleGenAiTextEmbeddingModel.java b/models/spring-ai-google-genai-embedding/src/main/java/org/springframework/ai/google/genai/text/GoogleGenAiTextEmbeddingModel.java index dba743af21e..629d2ebe2bb 100644 --- a/models/spring-ai-google-genai-embedding/src/main/java/org/springframework/ai/google/genai/text/GoogleGenAiTextEmbeddingModel.java +++ b/models/spring-ai-google-genai-embedding/src/main/java/org/springframework/ai/google/genai/text/GoogleGenAiTextEmbeddingModel.java @@ -46,7 +46,6 @@ import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.ai.retry.RetryUtils; -import org.springframework.core.retry.RetryException; import org.springframework.core.retry.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -169,19 +168,8 @@ public EmbeddingResponse call(EmbeddingRequest request) { } // Call the embedding API with retry - EmbedContentResponse embeddingResponse = null; - try { - embeddingResponse = this.retryTemplate - .execute(() -> this.genAiClient.models.embedContent(modelName, validTexts, config)); - } - catch (RetryException e) { - if (e.getCause() instanceof RuntimeException r) { - throw r; - } - else { - throw new RuntimeException(e.getCause()); - } - } + EmbedContentResponse embeddingResponse = RetryUtils.execute(this.retryTemplate, + () -> this.genAiClient.models.embedContent(modelName, validTexts, config)); // Process the response // Note: We need to handle the case where some texts were filtered out diff --git a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java index 70f5c1385dc..c5c30791670 100644 --- a/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java +++ b/models/spring-ai-google-genai/src/main/java/org/springframework/ai/google/genai/GoogleGenAiChatModel.java @@ -87,7 +87,6 @@ import org.springframework.ai.support.UsageCalculator; import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.beans.factory.DisposableBean; -import org.springframework.core.retry.RetryException; import org.springframework.core.retry.RetryTemplate; import org.springframework.lang.NonNull; import org.springframework.util.Assert; @@ -406,39 +405,31 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry) .observe(() -> { - try { - return this.retryTemplate.execute(() -> { - - var geminiRequest = createGeminiRequest(prompt); - - GenerateContentResponse generateContentResponse = this.getContentResponse(geminiRequest); - - List generations = generateContentResponse.candidates() - .orElse(List.of()) - .stream() - .map(this::responseCandidateToGeneration) - .flatMap(List::stream) - .toList(); - - var usage = generateContentResponse.usageMetadata(); - GoogleGenAiChatOptions options = (GoogleGenAiChatOptions) prompt.getOptions(); - Usage currentUsage = (usage.isPresent()) ? getDefaultUsage(usage.get(), options) - : getDefaultUsage(null, options); - Usage cumulativeUsage = UsageCalculator.getCumulativeUsage(currentUsage, previousChatResponse); - ChatResponse chatResponse = new ChatResponse(generations, - toChatResponseMetadata(cumulativeUsage, generateContentResponse.modelVersion().get())); - - observationContext.setResponse(chatResponse); - return chatResponse; - }); - } - catch (RetryException e) { - if (e.getCause() instanceof RuntimeException r) { - throw r; - } - throw new RuntimeException(e); - } + return RetryUtils.execute(this.retryTemplate, () -> { + + var geminiRequest = createGeminiRequest(prompt); + + GenerateContentResponse generateContentResponse = this.getContentResponse(geminiRequest); + + List generations = generateContentResponse.candidates() + .orElse(List.of()) + .stream() + .map(this::responseCandidateToGeneration) + .flatMap(List::stream) + .toList(); + + var usage = generateContentResponse.usageMetadata(); + GoogleGenAiChatOptions options = (GoogleGenAiChatOptions) prompt.getOptions(); + Usage currentUsage = (usage.isPresent()) ? getDefaultUsage(usage.get(), options) + : getDefaultUsage(null, options); + Usage cumulativeUsage = UsageCalculator.getCumulativeUsage(currentUsage, previousChatResponse); + ChatResponse chatResponse = new ChatResponse(generations, + toChatResponseMetadata(cumulativeUsage, generateContentResponse.modelVersion().get())); + + observationContext.setResponse(chatResponse); + return chatResponse; + }); }); if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java index fa72c428c86..721ed9c926c 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxChatModel.java @@ -68,7 +68,6 @@ import org.springframework.ai.model.tool.internal.ToolCallReactiveContextHolder; import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.tool.definition.ToolDefinition; -import org.springframework.core.retry.RetryException; import org.springframework.core.retry.RetryTemplate; import org.springframework.http.ResponseEntity; import org.springframework.util.Assert; @@ -254,18 +253,8 @@ public ChatResponse call(Prompt prompt) { this.observationRegistry) .observe(() -> { - ResponseEntity completionEntity = null; - try { - completionEntity = this.retryTemplate.execute(() -> this.miniMaxApi.chatCompletionEntity(request)); - } - catch (RetryException e) { - if (e.getCause() instanceof RuntimeException r) { - throw r; - } - else { - throw new RuntimeException(e.getCause()); - } - } + ResponseEntity completionEntity = RetryUtils.execute(this.retryTemplate, + () -> this.miniMaxApi.chatCompletionEntity(request)); var chatCompletion = completionEntity.getBody(); @@ -339,18 +328,8 @@ public Flux stream(Prompt prompt) { return Flux.deferContextual(contextView -> { ChatCompletionRequest request = createRequest(requestPrompt, true); - Flux completionChunks = null; - try { - completionChunks = this.retryTemplate.execute(() -> this.miniMaxApi.chatCompletionStream(request)); - } - catch (RetryException e) { - if (e.getCause() instanceof RuntimeException r) { - throw r; - } - else { - throw new RuntimeException(e.getCause()); - } - } + Flux completionChunks = RetryUtils.execute(this.retryTemplate, + () -> this.miniMaxApi.chatCompletionStream(request)); // For chunked responses, only the first chunk contains the choice role. // The rest of the chunks with same ID share the same role. diff --git a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingModel.java b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingModel.java index ab631ee1825..14b27e4895c 100644 --- a/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingModel.java +++ b/models/spring-ai-minimax/src/main/java/org/springframework/ai/minimax/MiniMaxEmbeddingModel.java @@ -40,7 +40,6 @@ import org.springframework.ai.minimax.api.MiniMaxApiConstants; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.retry.RetryUtils; -import org.springframework.core.retry.RetryException; import org.springframework.core.retry.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -166,19 +165,8 @@ public EmbeddingResponse call(EmbeddingRequest request) { .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry) .observe(() -> { - MiniMaxApi.EmbeddingList apiEmbeddingResponse = null; - try { - apiEmbeddingResponse = this.retryTemplate - .execute(() -> this.miniMaxApi.embeddings(apiRequest).getBody()); - } - catch (RetryException e) { - if (e.getCause() instanceof RuntimeException r) { - throw r; - } - else { - throw new RuntimeException(e.getCause()); - } - } + MiniMaxApi.EmbeddingList apiEmbeddingResponse = RetryUtils.execute(this.retryTemplate, + () -> this.miniMaxApi.embeddings(apiRequest).getBody()); if (apiEmbeddingResponse == null) { logger.warn("No embeddings returned for request: {}", request); diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java index eb1c77481b6..431c25afc19 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiChatModel.java @@ -69,7 +69,6 @@ import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.support.UsageCalculator; import org.springframework.ai.tool.definition.ToolDefinition; -import org.springframework.core.retry.RetryException; import org.springframework.core.retry.RetryTemplate; import org.springframework.http.ResponseEntity; import org.springframework.util.Assert; @@ -192,19 +191,8 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons this.observationRegistry) .observe(() -> { - ResponseEntity completionEntity = null; - try { - completionEntity = this.retryTemplate - .execute(() -> this.mistralAiApi.chatCompletionEntity(request)); - } - catch (RetryException e) { - if (e.getCause() instanceof RuntimeException r) { - throw r; - } - else { - throw new RuntimeException(e.getCause()); - } - } + ResponseEntity completionEntity = RetryUtils.execute(this.retryTemplate, + () -> this.mistralAiApi.chatCompletionEntity(request)); ChatCompletion chatCompletion = completionEntity.getBody(); @@ -276,18 +264,8 @@ public Flux internalStream(Prompt prompt, ChatResponse previousCha observation.parentObservation(contextView.getOrDefault(ObservationThreadLocalAccessor.KEY, null)).start(); - Flux completionChunks = null; - try { - completionChunks = this.retryTemplate.execute(() -> this.mistralAiApi.chatCompletionStream(request)); - } - catch (RetryException e) { - if (e.getCause() instanceof RuntimeException r) { - throw r; - } - else { - throw new RuntimeException(e.getCause()); - } - } + Flux completionChunks = RetryUtils.execute(this.retryTemplate, + () -> this.mistralAiApi.chatCompletionStream(request)); // For chunked responses, only the first chunk contains the choice role. // The rest of the chunks with same ID share the same role. diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java index 5e77f6413ca..ab22e6a1c45 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/MistralAiEmbeddingModel.java @@ -39,7 +39,6 @@ import org.springframework.ai.mistralai.api.MistralAiApi; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.retry.RetryUtils; -import org.springframework.core.retry.RetryException; import org.springframework.core.retry.RetryTemplate; import org.springframework.util.Assert; @@ -118,19 +117,8 @@ public EmbeddingResponse call(EmbeddingRequest request) { .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry) .observe(() -> { - MistralAiApi.EmbeddingList apiEmbeddingResponse = null; - try { - apiEmbeddingResponse = this.retryTemplate - .execute(() -> this.mistralAiApi.embeddings(apiRequest).getBody()); - } - catch (RetryException e) { - if (e.getCause() instanceof RuntimeException r) { - throw r; - } - else { - throw new RuntimeException(e.getCause()); - } - } + MistralAiApi.EmbeddingList apiEmbeddingResponse = RetryUtils + .execute(this.retryTemplate, () -> this.mistralAiApi.embeddings(apiRequest).getBody()); if (apiEmbeddingResponse == null) { logger.warn("No embeddings returned for request: {}", request); diff --git a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/moderation/MistralAiModerationModel.java b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/moderation/MistralAiModerationModel.java index 1eb5d7ad263..e80166ff342 100644 --- a/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/moderation/MistralAiModerationModel.java +++ b/models/spring-ai-mistral-ai/src/main/java/org/springframework/ai/mistralai/moderation/MistralAiModerationModel.java @@ -23,6 +23,9 @@ import org.slf4j.LoggerFactory; import org.springframework.ai.mistralai.api.MistralAiModerationApi; +import org.springframework.ai.mistralai.api.MistralAiModerationApi.MistralAiModerationRequest; +import org.springframework.ai.mistralai.api.MistralAiModerationApi.MistralAiModerationResponse; +import org.springframework.ai.mistralai.api.MistralAiModerationApi.MistralAiModerationResult; import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.moderation.Categories; import org.springframework.ai.moderation.CategoryScores; @@ -34,15 +37,10 @@ import org.springframework.ai.moderation.ModerationResponse; import org.springframework.ai.moderation.ModerationResult; import org.springframework.ai.retry.RetryUtils; -import org.springframework.core.retry.RetryException; import org.springframework.core.retry.RetryTemplate; import org.springframework.http.ResponseEntity; import org.springframework.util.Assert; -import static org.springframework.ai.mistralai.api.MistralAiModerationApi.MistralAiModerationRequest; -import static org.springframework.ai.mistralai.api.MistralAiModerationApi.MistralAiModerationResponse; -import static org.springframework.ai.mistralai.api.MistralAiModerationApi.MistralAiModerationResult; - /** * @author Ricken Bazolo * @author Jason Smith @@ -69,39 +67,29 @@ public MistralAiModerationModel(MistralAiModerationApi mistralAiModerationApi, R @Override public ModerationResponse call(ModerationPrompt moderationPrompt) { - try { - return this.retryTemplate.execute(() -> { - - var instructions = moderationPrompt.getInstructions().getText(); - var moderationRequest = new MistralAiModerationRequest(instructions); + return RetryUtils.execute(this.retryTemplate, () -> { - if (this.defaultOptions != null) { - moderationRequest = ModelOptionsUtils.merge(this.defaultOptions, moderationRequest, - MistralAiModerationRequest.class); - } - else { - // moderationPrompt.getOptions() never null but model can be empty, - // cause - // by ModerationPrompt constructor - moderationRequest = ModelOptionsUtils.merge( - toMistralAiModerationOptions(moderationPrompt.getOptions()), moderationRequest, - MistralAiModerationRequest.class); - } + var instructions = moderationPrompt.getInstructions().getText(); - var moderationResponseEntity = this.mistralAiModerationApi.moderate(moderationRequest); + var moderationRequest = new MistralAiModerationRequest(instructions); - return convertResponse(moderationResponseEntity, moderationRequest); - }); - } - catch (RetryException e) { - if (e.getCause() instanceof RuntimeException r) { - throw r; + if (this.defaultOptions != null) { + moderationRequest = ModelOptionsUtils.merge(this.defaultOptions, moderationRequest, + MistralAiModerationRequest.class); } else { - throw new RuntimeException(e.getCause()); + // moderationPrompt.getOptions() never null but model can be empty, + // cause + // by ModerationPrompt constructor + moderationRequest = ModelOptionsUtils.merge(toMistralAiModerationOptions(moderationPrompt.getOptions()), + moderationRequest, MistralAiModerationRequest.class); } - } + + var moderationResponseEntity = this.mistralAiModerationApi.moderate(moderationRequest); + + return convertResponse(moderationResponseEntity, moderationRequest); + }); } private ModerationResponse convertResponse(ResponseEntity moderationResponseEntity, diff --git a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java index 8da11283565..a0a4f09dec2 100644 --- a/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java +++ b/models/spring-ai-ollama/src/main/java/org/springframework/ai/ollama/OllamaChatModel.java @@ -69,7 +69,6 @@ import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.tool.definition.ToolDefinition; import org.springframework.ai.util.json.JsonParser; -import org.springframework.core.retry.RetryException; import org.springframework.core.retry.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -246,18 +245,8 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon this.observationRegistry) .observe(() -> { - OllamaApi.ChatResponse ollamaResponse = null; - try { - ollamaResponse = this.retryTemplate.execute(() -> this.chatApi.chat(request)); - } - catch (RetryException e) { - if (e.getCause() instanceof RuntimeException r) { - throw r; - } - else { - throw new RuntimeException(e.getCause()); - } - } + OllamaApi.ChatResponse ollamaResponse = RetryUtils.execute(this.retryTemplate, + () -> this.chatApi.chat(request)); List toolCalls = ollamaResponse.message().toolCalls() == null ? List.of() : ollamaResponse.message() diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechModel.java index b934d717391..3e0e36cc825 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiAudioSpeechModel.java @@ -129,13 +129,8 @@ public TextToSpeechResponse call(TextToSpeechPrompt prompt) { OpenAiAudioApi.SpeechRequest speechRequest = createRequest(prompt); - ResponseEntity speechEntity; - try { - speechEntity = this.retryTemplate.execute(() -> this.audioApi.createSpeech(speechRequest)); - } - catch (Exception e) { - throw new RuntimeException("Error calling OpenAI audio speech API", e); - } + ResponseEntity speechEntity = RetryUtils.execute(this.retryTemplate, + () -> this.audioApi.createSpeech(speechRequest)); var speech = speechEntity.getBody(); @@ -161,13 +156,8 @@ public Flux stream(TextToSpeechPrompt prompt) { OpenAiAudioApi.SpeechRequest speechRequest = createRequest(prompt); - Flux> speechEntity; - try { - speechEntity = this.retryTemplate.execute(() -> this.audioApi.stream(speechRequest)); - } - catch (Exception e) { - throw new RuntimeException("Error calling OpenAI audio speech streaming API", e); - } + Flux> speechEntity = RetryUtils.execute(this.retryTemplate, + () -> this.audioApi.stream(speechRequest)); return speechEntity.map(entity -> new TextToSpeechResponse(List.of(new Speech(entity.getBody())), new OpenAiAudioSpeechResponseMetadata(OpenAiResponseHeaderExtractor.extractAiResponseHeaders(entity)))); diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java index 54f940e5fc5..a53cbc64b46 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiChatModel.java @@ -195,14 +195,8 @@ public ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespons this.observationRegistry) .observe(() -> { - ResponseEntity completionEntity; - try { - completionEntity = this.retryTemplate - .execute(() -> this.openAiApi.chatCompletionEntity(request, getAdditionalHttpHeaders(prompt))); - } - catch (Exception e) { - throw new RuntimeException("Error calling OpenAI chat completion API", e); - } + ResponseEntity completionEntity = RetryUtils.execute(this.retryTemplate, + () -> this.openAiApi.chatCompletionEntity(request, getAdditionalHttpHeaders(prompt))); var chatCompletion = completionEntity.getBody(); diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java index d5c3cb347d4..df13c3ced50 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java @@ -164,14 +164,8 @@ public EmbeddingResponse call(EmbeddingRequest request) { .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry) .observe(() -> { - EmbeddingList apiEmbeddingResponse; - try { - apiEmbeddingResponse = this.retryTemplate - .execute(() -> this.openAiApi.embeddings(apiRequest).getBody()); - } - catch (Exception e) { - throw new RuntimeException("Error calling OpenAI embedding API", e); - } + EmbeddingList apiEmbeddingResponse = RetryUtils.execute(this.retryTemplate, + () -> this.openAiApi.embeddings(apiRequest).getBody()); if (apiEmbeddingResponse == null) { logger.warn("No embeddings returned for request: {}", request); diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageModel.java index d850c6ceef2..0e791ef8a5f 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageModel.java @@ -141,14 +141,8 @@ public ImageResponse call(ImagePrompt imagePrompt) { .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry) .observe(() -> { - ResponseEntity imageResponseEntity; - try { - imageResponseEntity = this.retryTemplate - .execute(() -> this.openAiImageApi.createImage(imageRequest)); - } - catch (Exception e) { - throw new RuntimeException("Error calling OpenAI image API", e); - } + ResponseEntity imageResponseEntity = RetryUtils + .execute(this.retryTemplate, () -> this.openAiImageApi.createImage(imageRequest)); ImageResponse imageResponse = convertResponse(imageResponseEntity, imageRequest); diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiModerationModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiModerationModel.java index 5a22be70ad0..6f0db4242c4 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiModerationModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiModerationModel.java @@ -77,34 +77,28 @@ public OpenAiModerationModel withDefaultOptions(OpenAiModerationOptions defaultO @Override public ModerationResponse call(ModerationPrompt moderationPrompt) { - try { - return this.retryTemplate.execute(() -> { + return RetryUtils.execute(this.retryTemplate, () -> { - String instructions = moderationPrompt.getInstructions().getText(); + String instructions = moderationPrompt.getInstructions().getText(); - OpenAiModerationApi.OpenAiModerationRequest moderationRequest = new OpenAiModerationApi.OpenAiModerationRequest( - instructions); + OpenAiModerationApi.OpenAiModerationRequest moderationRequest = new OpenAiModerationApi.OpenAiModerationRequest( + instructions); - if (this.defaultOptions != null) { - moderationRequest = ModelOptionsUtils.merge(this.defaultOptions, moderationRequest, - OpenAiModerationApi.OpenAiModerationRequest.class); - } + if (this.defaultOptions != null) { + moderationRequest = ModelOptionsUtils.merge(this.defaultOptions, moderationRequest, + OpenAiModerationApi.OpenAiModerationRequest.class); + } - if (moderationPrompt.getOptions() != null) { - moderationRequest = ModelOptionsUtils.merge( - toOpenAiModerationOptions(moderationPrompt.getOptions()), moderationRequest, - OpenAiModerationApi.OpenAiModerationRequest.class); - } + if (moderationPrompt.getOptions() != null) { + moderationRequest = ModelOptionsUtils.merge(toOpenAiModerationOptions(moderationPrompt.getOptions()), + moderationRequest, OpenAiModerationApi.OpenAiModerationRequest.class); + } - ResponseEntity moderationResponseEntity = this.openAiModerationApi - .createModeration(moderationRequest); + ResponseEntity moderationResponseEntity = this.openAiModerationApi + .createModeration(moderationRequest); - return convertResponse(moderationResponseEntity, moderationRequest); - }); - } - catch (Exception e) { - throw new RuntimeException("Error calling OpenAI moderation API", e); - } + return convertResponse(moderationResponseEntity, moderationRequest); + }); } private ModerationResponse convertResponse( diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java index 380ed5035ae..da506b30bc0 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModel.java @@ -50,7 +50,6 @@ import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.TextInstanceBuilder; import org.springframework.ai.vertexai.embedding.VertexAiEmbeddingUtils.TextParametersBuilder; -import org.springframework.core.retry.RetryException; import org.springframework.core.retry.RetryTemplate; import org.springframework.util.Assert; import org.springframework.util.StringUtils; @@ -139,8 +138,8 @@ public EmbeddingResponse call(EmbeddingRequest request) { PredictRequest.Builder predictRequestBuilder = getPredictRequestBuilder(request, endpointName, (VertexAiTextEmbeddingOptions) options); - PredictResponse embeddingResponse = this.retryTemplate - .execute(() -> getPredictResponse(client, predictRequestBuilder)); + PredictResponse embeddingResponse = RetryUtils.execute(this.retryTemplate, + () -> getPredictResponse(client, predictRequestBuilder)); int index = 0; int totalTokenCount = 0; @@ -164,14 +163,6 @@ public EmbeddingResponse call(EmbeddingRequest request) { return response; } - catch (RetryException e) { - if (e.getCause() instanceof RuntimeException r) { - throw r; - } - else { - throw new RuntimeException(e.getCause()); - } - } }); } diff --git a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java index c7b95ddad88..a32dd3fb5d9 100644 --- a/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java +++ b/models/spring-ai-vertex-ai-gemini/src/main/java/org/springframework/ai/vertexai/gemini/VertexAiGeminiChatModel.java @@ -91,7 +91,6 @@ import org.springframework.ai.vertexai.gemini.schema.VertexAiSchemaConverter; import org.springframework.ai.vertexai.gemini.schema.VertexToolCallingManager; import org.springframework.beans.factory.DisposableBean; -import org.springframework.core.retry.RetryException; import org.springframework.core.retry.RetryTemplate; import org.springframework.lang.NonNull; import org.springframework.util.Assert; @@ -391,43 +390,29 @@ private ChatResponse internalCall(Prompt prompt, ChatResponse previousChatRespon .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry) .observe(() -> { - try { - return this.retryTemplate.execute(() -> { - - var geminiRequest = createGeminiRequest(prompt); - - GenerateContentResponse generateContentResponse = this.getContentResponse(geminiRequest); - - List generations = generateContentResponse.getCandidatesList() - .stream() - .map(this::responseCandidateToGeneration) - .flatMap(List::stream) - .toList(); - - GenerateContentResponse.UsageMetadata usage = generateContentResponse.getUsageMetadata(); - Usage currentUsage = (usage != null) - ? new DefaultUsage(usage.getPromptTokenCount(), usage.getCandidatesTokenCount()) - : new EmptyUsage(); - Usage cumulativeUsage = UsageCalculator.getCumulativeUsage(currentUsage, previousChatResponse); - ChatResponse chatResponse = new ChatResponse(generations, - toChatResponseMetadata(cumulativeUsage)); - - observationContext.setResponse(chatResponse); - return chatResponse; - }); - } - catch (RetryException e) { - if (e.getCause() instanceof RuntimeException r) { - throw r; - } - if (e.getCause() instanceof RuntimeException r) { - throw r; - } - else { - throw new RuntimeException(e.getCause()); - } - } + return RetryUtils.execute(this.retryTemplate, () -> { + + var geminiRequest = createGeminiRequest(prompt); + + GenerateContentResponse generateContentResponse = this.getContentResponse(geminiRequest); + + List generations = generateContentResponse.getCandidatesList() + .stream() + .map(this::responseCandidateToGeneration) + .flatMap(List::stream) + .toList(); + + GenerateContentResponse.UsageMetadata usage = generateContentResponse.getUsageMetadata(); + Usage currentUsage = (usage != null) + ? new DefaultUsage(usage.getPromptTokenCount(), usage.getCandidatesTokenCount()) + : new EmptyUsage(); + Usage cumulativeUsage = UsageCalculator.getCumulativeUsage(currentUsage, previousChatResponse); + ChatResponse chatResponse = new ChatResponse(generations, toChatResponseMetadata(cumulativeUsage)); + + observationContext.setResponse(chatResponse); + return chatResponse; + }); }); if (this.toolExecutionEligibilityPredicate.isToolExecutionRequired(prompt.getOptions(), response)) { diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java index 4c6e0225bf9..8086b48c4a9 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiChatModel.java @@ -70,7 +70,6 @@ import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionMessage.ToolCall; import org.springframework.ai.zhipuai.api.ZhiPuAiApi.ChatCompletionRequest; import org.springframework.ai.zhipuai.api.ZhiPuApiConstants; -import org.springframework.core.retry.RetryException; import org.springframework.core.retry.RetryTemplate; import org.springframework.http.ResponseEntity; import org.springframework.util.Assert; @@ -261,18 +260,8 @@ public ChatResponse call(Prompt prompt) { this.observationRegistry) .observe(() -> { - ResponseEntity completionEntity = null; - try { - completionEntity = this.retryTemplate.execute(() -> this.zhiPuAiApi.chatCompletionEntity(request)); - } - catch (RetryException e) { - if (e.getCause() instanceof RuntimeException r) { - throw r; - } - else { - throw new RuntimeException(e.getCause()); - } - } + ResponseEntity completionEntity = RetryUtils.execute(this.retryTemplate, + () -> this.zhiPuAiApi.chatCompletionEntity(request)); var chatCompletion = completionEntity.getBody(); @@ -330,18 +319,8 @@ public Flux stream(Prompt prompt) { Prompt requestPrompt = buildRequestPrompt(prompt); ChatCompletionRequest request = createRequest(requestPrompt, true); - Flux completionChunks = null; - try { - completionChunks = this.retryTemplate.execute(() -> this.zhiPuAiApi.chatCompletionStream(request)); - } - catch (RetryException e) { - if (e.getCause() instanceof RuntimeException r) { - throw r; - } - else { - throw new RuntimeException(e.getCause()); - } - } + Flux completionChunks = RetryUtils.execute(this.retryTemplate, + () -> this.zhiPuAiApi.chatCompletionStream(request)); // For chunked responses, only the first chunk contains the choice role. // The rest of the chunks with same ID share the same role. diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java index 7a60ddb883c..a98ccf92d87 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiEmbeddingModel.java @@ -41,7 +41,6 @@ import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.zhipuai.api.ZhiPuAiApi; import org.springframework.ai.zhipuai.api.ZhiPuApiConstants; -import org.springframework.core.retry.RetryException; import org.springframework.core.retry.RetryTemplate; import org.springframework.http.ResponseEntity; import org.springframework.util.Assert; @@ -167,19 +166,8 @@ public EmbeddingResponse call(EmbeddingRequest request) { .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, this.observationRegistry) .observe(() -> { - ResponseEntity> embeddingResponse = null; - try { - embeddingResponse = this.retryTemplate - .execute(() -> this.zhiPuAiApi.embeddings(zhipuEmbeddingRequest)); - } - catch (RetryException e) { - if (e.getCause() instanceof RuntimeException r) { - throw r; - } - else { - throw new RuntimeException(e.getCause()); - } - } + ResponseEntity> embeddingResponse = RetryUtils + .execute(this.retryTemplate, () -> this.zhiPuAiApi.embeddings(zhipuEmbeddingRequest)); if (embeddingResponse == null || embeddingResponse.getBody() == null || CollectionUtils.isEmpty(embeddingResponse.getBody().data())) { diff --git a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageModel.java b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageModel.java index e88231a9929..d3c4d124c66 100644 --- a/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageModel.java +++ b/models/spring-ai-zhipuai/src/main/java/org/springframework/ai/zhipuai/ZhiPuAiImageModel.java @@ -30,7 +30,6 @@ import org.springframework.ai.model.ModelOptionsUtils; import org.springframework.ai.retry.RetryUtils; import org.springframework.ai.zhipuai.api.ZhiPuAiImageApi; -import org.springframework.core.retry.RetryException; import org.springframework.core.retry.RetryTemplate; import org.springframework.http.ResponseEntity; import org.springframework.util.Assert; @@ -72,40 +71,31 @@ public ZhiPuAiImageOptions getDefaultOptions() { @Override public ImageResponse call(ImagePrompt imagePrompt) { - try { - return this.retryTemplate.execute(() -> { - String instructions = imagePrompt.getInstructions().get(0).getText(); + return RetryUtils.execute(this.retryTemplate, () -> { - ZhiPuAiImageApi.ZhiPuAiImageRequest imageRequest = new ZhiPuAiImageApi.ZhiPuAiImageRequest(instructions, - ZhiPuAiImageApi.DEFAULT_IMAGE_MODEL); + String instructions = imagePrompt.getInstructions().get(0).getText(); - if (this.defaultOptions != null) { - imageRequest = ModelOptionsUtils.merge(this.defaultOptions, imageRequest, - ZhiPuAiImageApi.ZhiPuAiImageRequest.class); - } - - if (imagePrompt.getOptions() != null) { - imageRequest = ModelOptionsUtils.merge(toZhiPuAiImageOptions(imagePrompt.getOptions()), - imageRequest, ZhiPuAiImageApi.ZhiPuAiImageRequest.class); - } + ZhiPuAiImageApi.ZhiPuAiImageRequest imageRequest = new ZhiPuAiImageApi.ZhiPuAiImageRequest(instructions, + ZhiPuAiImageApi.DEFAULT_IMAGE_MODEL); - // Make the request - ResponseEntity imageResponseEntity = this.zhiPuAiImageApi - .createImage(imageRequest); - - // Convert to org.springframework.ai.model derived ImageResponse data type - return convertResponse(imageResponseEntity, imageRequest); - }); - } - catch (RetryException e) { - if (e.getCause() instanceof RuntimeException r) { - throw r; + if (this.defaultOptions != null) { + imageRequest = ModelOptionsUtils.merge(this.defaultOptions, imageRequest, + ZhiPuAiImageApi.ZhiPuAiImageRequest.class); } - else { - throw new RuntimeException(e.getCause()); + + if (imagePrompt.getOptions() != null) { + imageRequest = ModelOptionsUtils.merge(toZhiPuAiImageOptions(imagePrompt.getOptions()), imageRequest, + ZhiPuAiImageApi.ZhiPuAiImageRequest.class); } - } + + // Make the request + ResponseEntity imageResponseEntity = this.zhiPuAiImageApi + .createImage(imageRequest); + + // Convert to org.springframework.ai.model derived ImageResponse data type + return convertResponse(imageResponseEntity, imageRequest); + }); } private ImageResponse convertResponse(ResponseEntity imageResponseEntity, diff --git a/spring-ai-retry/src/main/java/org/springframework/ai/retry/RetryUtils.java b/spring-ai-retry/src/main/java/org/springframework/ai/retry/RetryUtils.java index 2a1f3c97c73..24acdf5987e 100644 --- a/spring-ai-retry/src/main/java/org/springframework/ai/retry/RetryUtils.java +++ b/spring-ai-retry/src/main/java/org/springframework/ai/retry/RetryUtils.java @@ -22,9 +22,11 @@ import java.time.Duration; import java.util.concurrent.atomic.AtomicInteger; +import org.jspecify.annotations.Nullable; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.core.retry.RetryException; import org.springframework.core.retry.RetryListener; import org.springframework.core.retry.RetryPolicy; import org.springframework.core.retry.RetryTemplate; @@ -155,4 +157,21 @@ public void onRetryFailure(final RetryPolicy policy, final Retryable retryabl return retryTemplate; } + /** + * Generic execute method to run retryable operations with the provided RetryTemplate. + * @param the return type + * @param retryTemplate the RetryTemplate to use for executing the retryable operation + * @param retryable the operation to be retried + * @return the result of the retryable operation + */ + public static R execute(RetryTemplate retryTemplate, Retryable retryable) { + try { + return retryTemplate.execute(retryable); + } + catch (RetryException e) { + throw (e.getCause() instanceof RuntimeException runtime) ? runtime + : new RuntimeException(e.getMessage(), e.getCause()); + } + } + }