diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java index 1d316f47200..b2a006dc4f6 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java @@ -66,7 +66,9 @@ public class AzureOpenAiAutoConfiguration { @Bean @ConditionalOnMissingBean // ({ OpenAIClient.class, TokenCredential.class }) - public OpenAIClientBuilder openAIClientBuilder(AzureOpenAiConnectionProperties connectionProperties) { + public OpenAIClientBuilder openAIClientBuilder(AzureOpenAiConnectionProperties connectionProperties, + ObjectProvider customizers) { + if (StringUtils.hasText(connectionProperties.getApiKey())) { Assert.hasText(connectionProperties.getEndpoint(), "Endpoint must not be empty"); @@ -77,17 +79,21 @@ public OpenAIClientBuilder openAIClientBuilder(AzureOpenAiConnectionProperties c .map(entry -> new Header(entry.getKey(), entry.getValue())) .collect(Collectors.toList()); ClientOptions clientOptions = new ClientOptions().setApplicationId(APPLICATION_ID).setHeaders(headers); - return new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint()) + OpenAIClientBuilder clientBuilder = new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint()) .credential(new AzureKeyCredential(connectionProperties.getApiKey())) .clientOptions(clientOptions); + applyOpenAIClientBuilderCustomizers(clientBuilder, customizers); + return clientBuilder; } // Connect to OpenAI (e.g. not the Azure OpenAI). The deploymentName property is // used as OpenAI model name. if (StringUtils.hasText(connectionProperties.getOpenAiApiKey())) { - return new OpenAIClientBuilder().endpoint("https://api.openai.com/v1") + OpenAIClientBuilder clientBuilder = new OpenAIClientBuilder().endpoint("https://api.openai.com/v1") .credential(new KeyCredential(connectionProperties.getOpenAiApiKey())) .clientOptions(new ClientOptions().setApplicationId(APPLICATION_ID)); + applyOpenAIClientBuilderCustomizers(clientBuilder, customizers); + return clientBuilder; } throw new IllegalArgumentException("Either API key or OpenAI API key must not be empty"); @@ -97,14 +103,16 @@ public OpenAIClientBuilder openAIClientBuilder(AzureOpenAiConnectionProperties c @ConditionalOnMissingBean @ConditionalOnBean(TokenCredential.class) public OpenAIClientBuilder openAIClientWithTokenCredential(AzureOpenAiConnectionProperties connectionProperties, - TokenCredential tokenCredential) { + TokenCredential tokenCredential, ObjectProvider customizers) { Assert.notNull(tokenCredential, "TokenCredential must not be null"); Assert.hasText(connectionProperties.getEndpoint(), "Endpoint must not be empty"); - return new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint()) + OpenAIClientBuilder clientBuilder = new OpenAIClientBuilder().endpoint(connectionProperties.getEndpoint()) .credential(tokenCredential) .clientOptions(new ClientOptions().setApplicationId(APPLICATION_ID)); + applyOpenAIClientBuilderCustomizers(clientBuilder, customizers); + return clientBuilder; } @Bean @@ -169,4 +177,9 @@ public AzureOpenAiAudioTranscriptionModel azureOpenAiAudioTranscriptionModel(Ope return new AzureOpenAiAudioTranscriptionModel(openAIClient.buildClient(), audioProperties.getOptions()); } + private void applyOpenAIClientBuilderCustomizers(OpenAIClientBuilder clientBuilder, + ObjectProvider customizers) { + customizers.orderedStream().forEach(customizer -> customizer.customize(clientBuilder)); + } + } diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/OpenAIClientBuilderCustomizer.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/OpenAIClientBuilderCustomizer.java new file mode 100644 index 00000000000..40a87c034e2 --- /dev/null +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/OpenAIClientBuilderCustomizer.java @@ -0,0 +1,18 @@ +package org.springframework.ai.autoconfigure.azure.openai; + +import com.azure.ai.openai.OpenAIClientBuilder; + +/** + * Callback interface that can be implemented by beans wishing to customize the + * {@link OpenAIClientBuilder} whilst retaining the default auto-configuration. + */ +@FunctionalInterface +public interface OpenAIClientBuilderCustomizer { + + /** + * Customize the {@link OpenAIClientBuilder}. + * @param clientBuilder the {@link OpenAIClientBuilder} to customize + */ + void customize(OpenAIClientBuilder clientBuilder); + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java index 8af26e0b753..fc8bc473488 100644 --- a/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java +++ b/spring-ai-spring-boot-autoconfigure/src/test/java/org/springframework/ai/autoconfigure/azure/AzureOpenAiAutoConfigurationIT.java @@ -20,6 +20,7 @@ import java.net.URI; import java.util.List; import java.util.Map; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.stream.Collectors; import com.azure.ai.openai.OpenAIClient; @@ -33,6 +34,7 @@ import com.azure.core.http.HttpResponse; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.autoconfigure.azure.openai.OpenAIClientBuilderCustomizer; import reactor.core.publisher.Flux; import org.springframework.ai.autoconfigure.azure.openai.AzureOpenAiAutoConfiguration; @@ -228,4 +230,20 @@ void audioTranscriptionActivation() { .run(context -> assertThat(context.getBeansOfType(AzureOpenAiAudioTranscriptionModel.class)).isNotEmpty()); } + @Test + void openAIClientBuilderCustomizer() { + AtomicBoolean firstCustomizationApplied = new AtomicBoolean(false); + AtomicBoolean secondCustomizationApplied = new AtomicBoolean(false); + this.contextRunner + .withBean("first", OpenAIClientBuilderCustomizer.class, + () -> clientBuilder -> firstCustomizationApplied.set(true)) + .withBean("second", OpenAIClientBuilderCustomizer.class, + () -> clientBuilder -> secondCustomizationApplied.set(true)) + .run(context -> { + context.getBean(OpenAIClientBuilder.class); + assertThat(firstCustomizationApplied.get()).isTrue(); + assertThat(secondCustomizationApplied.get()).isTrue(); + }); + } + }