diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java index da419e4ca50..7871b305e92 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiChatModel.java @@ -687,4 +687,13 @@ private ChatCompletionsResponseFormat toAzureResponseFormat(AzureOpenAiResponseF return new ChatCompletionsTextResponseFormat(); } + /** + * Use the provided convention for reporting observation data + * @param observationConvention The provided convention + */ + public void setObservationConvention(ChatModelObservationConvention observationConvention) { + Assert.notNull(observationConvention, "observationConvention cannot be null"); + this.observationConvention = observationConvention; + } + } diff --git a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java index 546c4ed29ed..17827585124 100644 --- a/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java +++ b/models/spring-ai-azure-openai/src/main/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModel.java @@ -19,6 +19,9 @@ import com.azure.ai.openai.models.EmbeddingItem; import com.azure.ai.openai.models.Embeddings; import com.azure.ai.openai.models.EmbeddingsOptions; + +import io.micrometer.observation.ObservationRegistry; + import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.ai.azure.openai.metadata.AzureOpenAiEmbeddingUsage; @@ -29,7 +32,12 @@ import org.springframework.ai.embedding.EmbeddingRequest; import org.springframework.ai.embedding.EmbeddingResponse; import org.springframework.ai.embedding.EmbeddingResponseMetadata; +import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention; +import org.springframework.ai.embedding.observation.EmbeddingModelObservationContext; +import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; +import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation; import org.springframework.ai.model.EmbeddingUtils; +import org.springframework.ai.observation.conventions.AiProvider; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; @@ -54,6 +62,18 @@ public class AzureOpenAiEmbeddingModel extends AbstractEmbeddingModel { private final MetadataMode metadataMode; + private static final EmbeddingModelObservationConvention DEFAULT_OBSERVATION_CONVENTION = new DefaultEmbeddingModelObservationConvention(); + + /** + * Observation registry used for instrumentation. + */ + private final ObservationRegistry observationRegistry; + + /** + * Conventions to use for generating observations. + */ + private EmbeddingModelObservationConvention observationConvention = DEFAULT_OBSERVATION_CONVENTION; + public AzureOpenAiEmbeddingModel(OpenAIClient azureOpenAiClient) { this(azureOpenAiClient, MetadataMode.EMBED); } @@ -65,12 +85,20 @@ public AzureOpenAiEmbeddingModel(OpenAIClient azureOpenAiClient, MetadataMode me public AzureOpenAiEmbeddingModel(OpenAIClient azureOpenAiClient, MetadataMode metadataMode, AzureOpenAiEmbeddingOptions options) { + this(azureOpenAiClient, metadataMode, options, ObservationRegistry.NOOP); + } + + public AzureOpenAiEmbeddingModel(OpenAIClient azureOpenAiClient, MetadataMode metadataMode, + AzureOpenAiEmbeddingOptions options, ObservationRegistry observationRegistry) { + Assert.notNull(azureOpenAiClient, "com.azure.ai.openai.OpenAIClient must not be null"); Assert.notNull(metadataMode, "Metadata mode must not be null"); Assert.notNull(options, "Options must not be null"); + Assert.notNull(observationRegistry, "Observation registry must not be null"); this.azureOpenAiClient = azureOpenAiClient; this.metadataMode = metadataMode; this.defaultOptions = options; + this.observationRegistry = observationRegistry; } @Override @@ -91,11 +119,29 @@ public float[] embed(Document document) { public EmbeddingResponse call(EmbeddingRequest embeddingRequest) { logger.debug("Retrieving embeddings"); - EmbeddingsOptions azureOptions = toEmbeddingOptions(embeddingRequest); - Embeddings embeddings = this.azureOpenAiClient.getEmbeddings(azureOptions.getModel(), azureOptions); - - logger.debug("Embeddings retrieved"); - return generateEmbeddingResponse(embeddings); + AzureOpenAiEmbeddingOptions options = AzureOpenAiEmbeddingOptions.builder() + .from(this.defaultOptions) + .merge(embeddingRequest.getOptions()) + .build(); + EmbeddingsOptions azureOptions = options.toAzureOptions(embeddingRequest.getInstructions()); + + var observationContext = EmbeddingModelObservationContext.builder() + .embeddingRequest(embeddingRequest) + .provider(AiProvider.AZURE_OPENAI.value()) + .requestOptions(options) + .build(); + + return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION + .observation(this.observationConvention, DEFAULT_OBSERVATION_CONVENTION, () -> observationContext, + this.observationRegistry) + .observe(() -> { + Embeddings embeddings = this.azureOpenAiClient.getEmbeddings(azureOptions.getModel(), azureOptions); + + logger.debug("Embeddings retrieved"); + var embeddingResponse = generateEmbeddingResponse(embeddings); + observationContext.setResponse(embeddingResponse); + return embeddingResponse; + }); } /** @@ -132,4 +178,13 @@ public AzureOpenAiEmbeddingOptions getDefaultOptions() { return this.defaultOptions; } + /** + * Use the provided convention for reporting observation data + * @param observationConvention The provided convention + */ + public void setObservationConvention(EmbeddingModelObservationConvention observationConvention) { + Assert.notNull(observationConvention, "observationConvention cannot be null"); + this.observationConvention = observationConvention; + } + } diff --git a/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModelObservationIT.java b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModelObservationIT.java new file mode 100644 index 00000000000..002615b74d4 --- /dev/null +++ b/models/spring-ai-azure-openai/src/test/java/org/springframework/ai/azure/openai/AzureOpenAiEmbeddingModelObservationIT.java @@ -0,0 +1,119 @@ +/* + * Copyright 2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.azure.openai; + +import static org.assertj.core.api.Assertions.assertThat; + +import java.util.List; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.document.MetadataMode; +import org.springframework.ai.embedding.EmbeddingRequest; +import org.springframework.ai.embedding.EmbeddingResponse; +import org.springframework.ai.embedding.EmbeddingResponseMetadata; +import org.springframework.ai.embedding.observation.DefaultEmbeddingModelObservationConvention; +import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.HighCardinalityKeyNames; +import org.springframework.ai.embedding.observation.EmbeddingModelObservationDocumentation.LowCardinalityKeyNames; +import org.springframework.ai.observation.conventions.AiOperationType; +import org.springframework.ai.observation.conventions.AiProvider; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.SpringBootConfiguration; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.context.annotation.Bean; + +import com.azure.ai.openai.OpenAIClient; +import com.azure.ai.openai.OpenAIClientBuilder; +import com.azure.core.credential.AzureKeyCredential; + +import io.micrometer.observation.tck.TestObservationRegistry; +import io.micrometer.observation.tck.TestObservationRegistryAssert; + +/** + * Integration tests for observation instrumentation in {@link AzureOpenAiEmbeddingModel}. + * + * @author Christian Tzolov + */ +@SpringBootTest(classes = AzureOpenAiEmbeddingModelObservationIT.Config.class) +@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_API_KEY", matches = ".+") +@EnabledIfEnvironmentVariable(named = "AZURE_OPENAI_ENDPOINT", matches = ".+") +public class AzureOpenAiEmbeddingModelObservationIT { + + @Autowired + TestObservationRegistry observationRegistry; + + @Autowired + AzureOpenAiEmbeddingModel embeddingModel; + + @Test + void observationForEmbeddingOperation() { + var options = AzureOpenAiEmbeddingOptions.builder() + .withDeploymentName("text-embedding-ada-002") + .withDimensions(1536) + .build(); + + EmbeddingRequest embeddingRequest = new EmbeddingRequest(List.of("Here comes the sun"), options); + + EmbeddingResponse embeddingResponse = embeddingModel.call(embeddingRequest); + assertThat(embeddingResponse.getResults()).isNotEmpty(); + + EmbeddingResponseMetadata responseMetadata = embeddingResponse.getMetadata(); + assertThat(responseMetadata).isNotNull(); + + TestObservationRegistryAssert.assertThat(observationRegistry) + .doesNotHaveAnyRemainingCurrentObservation() + .hasObservationWithNameEqualTo(DefaultEmbeddingModelObservationConvention.DEFAULT_NAME) + .that() + .hasContextualNameEqualTo("embedding " + "text-embedding-ada-002") + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_OPERATION_TYPE.asString(), + AiOperationType.EMBEDDING.value()) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.AI_PROVIDER.asString(), AiProvider.AZURE_OPENAI.value()) + .hasLowCardinalityKeyValue(LowCardinalityKeyNames.REQUEST_MODEL.asString(), "text-embedding-ada-002") + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.REQUEST_EMBEDDING_DIMENSIONS.asString(), "1536") + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_INPUT_TOKENS.asString(), + String.valueOf(responseMetadata.getUsage().getPromptTokens())) + .hasHighCardinalityKeyValue(HighCardinalityKeyNames.USAGE_TOTAL_TOKENS.asString(), + String.valueOf(responseMetadata.getUsage().getTotalTokens())) + .hasBeenStarted() + .hasBeenStopped(); + } + + @SpringBootConfiguration + static class Config { + + @Bean + public TestObservationRegistry observationRegistry() { + return TestObservationRegistry.create(); + } + + @Bean + public OpenAIClient openAIClient() { + return new OpenAIClientBuilder().credential(new AzureKeyCredential(System.getenv("AZURE_OPENAI_API_KEY"))) + .endpoint(System.getenv("AZURE_OPENAI_ENDPOINT")) + .buildClient(); + } + + @Bean + public AzureOpenAiEmbeddingModel azureEmbeddingModel(OpenAIClient openAIClient, + TestObservationRegistry observationRegistry) { + return new AzureOpenAiEmbeddingModel(openAIClient, MetadataMode.EMBED, + AzureOpenAiEmbeddingOptions.builder().withDeploymentName("text-embedding-ada-002").build(), + observationRegistry); + } + + } + +} diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java index 813208830b7..ee686d44c27 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/azure/openai/AzureOpenAiAutoConfiguration.java @@ -23,8 +23,11 @@ import org.springframework.ai.azure.openai.AzureOpenAiChatModel; import org.springframework.ai.azure.openai.AzureOpenAiEmbeddingModel; import org.springframework.ai.azure.openai.AzureOpenAiImageModel; +import org.springframework.ai.chat.observation.ChatModelObservationConvention; +import org.springframework.ai.embedding.observation.EmbeddingModelObservationConvention; import org.springframework.ai.model.function.FunctionCallback; import org.springframework.ai.model.function.FunctionCallbackContext; +import org.springframework.beans.factory.ObjectProvider; import org.springframework.boot.autoconfigure.AutoConfiguration; import org.springframework.boot.autoconfigure.condition.ConditionalOnBean; import org.springframework.boot.autoconfigure.condition.ConditionalOnClass; @@ -36,7 +39,6 @@ import org.springframework.util.Assert; import org.springframework.util.StringUtils; -import com.azure.ai.openai.OpenAIClient; import com.azure.ai.openai.OpenAIClientBuilder; import com.azure.core.credential.AzureKeyCredential; import com.azure.core.credential.KeyCredential; @@ -44,6 +46,8 @@ import com.azure.core.util.ClientOptions; import com.azure.core.util.Header; +import io.micrometer.observation.ObservationRegistry; + /** * @author Piotr Olaszewski * @author Soby Chacko @@ -106,10 +110,15 @@ public OpenAIClientBuilder openAIClientWithTokenCredential(AzureOpenAiConnection matchIfMissing = true) public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder openAIClientBuilder, AzureOpenAiChatProperties chatProperties, List toolFunctionCallbacks, - FunctionCallbackContext functionCallbackContext) { + FunctionCallbackContext functionCallbackContext, ObjectProvider observationRegistry, + ObjectProvider observationConvention) { + + var chatModel = new AzureOpenAiChatModel(openAIClientBuilder, chatProperties.getOptions(), + functionCallbackContext, toolFunctionCallbacks, + observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)); + observationConvention.ifAvailable(chatModel::setObservationConvention); - return new AzureOpenAiChatModel(openAIClientBuilder, chatProperties.getOptions(), functionCallbackContext, - toolFunctionCallbacks); + return chatModel; } @Bean @@ -117,9 +126,17 @@ public AzureOpenAiChatModel azureOpenAiChatModel(OpenAIClientBuilder openAIClien @ConditionalOnProperty(prefix = AzureOpenAiEmbeddingProperties.CONFIG_PREFIX, name = "enabled", havingValue = "true", matchIfMissing = true) public AzureOpenAiEmbeddingModel azureOpenAiEmbeddingModel(OpenAIClientBuilder openAIClient, - AzureOpenAiEmbeddingProperties embeddingProperties) { - return new AzureOpenAiEmbeddingModel(openAIClient.buildClient(), embeddingProperties.getMetadataMode(), - embeddingProperties.getOptions()); + AzureOpenAiEmbeddingProperties embeddingProperties, ObjectProvider observationRegistry, + ObjectProvider observationConvention) { + + var embeddingModel = new AzureOpenAiEmbeddingModel(openAIClient.buildClient(), + embeddingProperties.getMetadataMode(), embeddingProperties.getOptions(), + observationRegistry.getIfUnique(() -> ObservationRegistry.NOOP)); + + observationConvention.ifAvailable(embeddingModel::setObservationConvention); + + return embeddingModel; + } @Bean