From 39e3c6d60b76885f54372543137b1ca519d88fbb Mon Sep 17 00:00:00 2001 From: Thomas Vitale Date: Wed, 19 Mar 2025 13:27:40 -0700 Subject: [PATCH] openai: Adopt new strategy for ObservationContext Relates to gh-2518 Signed-off-by: Thomas Vitale --- .../ai/openai/OpenAiEmbeddingModel.java | 47 ++++++++-------- .../ai/openai/OpenAiImageModel.java | 54 +++++++++---------- .../ai/embedding/EmbeddingRequest.java | 5 +- 3 files changed, 54 insertions(+), 52 deletions(-) diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java index 712b6b05f2f..5bed3ea2b96 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiEmbeddingModel.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -42,7 +42,6 @@ import org.springframework.ai.openai.api.OpenAiApi.EmbeddingList; import org.springframework.ai.openai.api.common.OpenAiApiConstants; import org.springframework.ai.retry.RetryUtils; -import org.springframework.lang.Nullable; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; @@ -148,13 +147,16 @@ public float[] embed(Document document) { @Override public EmbeddingResponse call(EmbeddingRequest request) { - OpenAiEmbeddingOptions requestOptions = mergeOptions(request.getOptions(), this.defaultOptions); - OpenAiApi.EmbeddingRequest> apiRequest = createRequest(request, requestOptions); + // Before moving any further, build the final request EmbeddingRequest, + // merging runtime and default options. + EmbeddingRequest embeddingRequest = buildEmbeddingRequest(request); + + OpenAiApi.EmbeddingRequest> apiRequest = createRequest(embeddingRequest); var observationContext = EmbeddingModelObservationContext.builder() .embeddingRequest(request) .provider(OpenAiApiConstants.PROVIDER_NAME) - .requestOptions(requestOptions) + .requestOptions(embeddingRequest.getOptions()) .build(); return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION @@ -190,35 +192,32 @@ private DefaultUsage getDefaultUsage(OpenAiApi.Usage usage) { return new DefaultUsage(usage.promptTokens(), usage.completionTokens(), usage.totalTokens(), usage); } - private OpenAiApi.EmbeddingRequest> createRequest(EmbeddingRequest request, - OpenAiEmbeddingOptions requestOptions) { + private OpenAiApi.EmbeddingRequest> createRequest(EmbeddingRequest request) { + OpenAiEmbeddingOptions requestOptions = (OpenAiEmbeddingOptions) request.getOptions(); return new OpenAiApi.EmbeddingRequest<>(request.getInstructions(), requestOptions.getModel(), requestOptions.getEncodingFormat(), requestOptions.getDimensions(), requestOptions.getUser()); } - /** - * Merge runtime and default {@link EmbeddingOptions} to compute the final options to - * use in the request. - */ - private OpenAiEmbeddingOptions mergeOptions(@Nullable EmbeddingOptions runtimeOptions, - OpenAiEmbeddingOptions defaultOptions) { - var runtimeOptionsForProvider = ModelOptionsUtils.copyToTarget(runtimeOptions, EmbeddingOptions.class, - OpenAiEmbeddingOptions.class); - - if (runtimeOptionsForProvider == null) { - return defaultOptions; + private EmbeddingRequest buildEmbeddingRequest(EmbeddingRequest embeddingRequest) { + // Process runtime options + OpenAiEmbeddingOptions runtimeOptions = null; + if (embeddingRequest.getOptions() != null) { + runtimeOptions = ModelOptionsUtils.copyToTarget(embeddingRequest.getOptions(), EmbeddingOptions.class, + OpenAiEmbeddingOptions.class); } - return OpenAiEmbeddingOptions.builder() + OpenAiEmbeddingOptions requestOptions = runtimeOptions == null ? this.defaultOptions : OpenAiEmbeddingOptions + .builder() // Handle portable embedding options - .model(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getModel(), defaultOptions.getModel())) - .dimensions(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getDimensions(), - defaultOptions.getDimensions())) + .model(ModelOptionsUtils.mergeOption(runtimeOptions.getModel(), this.defaultOptions.getModel())) + .dimensions(ModelOptionsUtils.mergeOption(runtimeOptions.getDimensions(), defaultOptions.getDimensions())) // Handle OpenAI specific embedding options - .encodingFormat(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getEncodingFormat(), + .encodingFormat(ModelOptionsUtils.mergeOption(runtimeOptions.getEncodingFormat(), defaultOptions.getEncodingFormat())) - .user(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getUser(), defaultOptions.getUser())) + .user(ModelOptionsUtils.mergeOption(runtimeOptions.getUser(), this.defaultOptions.getUser())) .build(); + + return new EmbeddingRequest(embeddingRequest.getInstructions(), requestOptions); } /** diff --git a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageModel.java b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageModel.java index 336a44d93bd..88096f15399 100644 --- a/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageModel.java +++ b/models/spring-ai-openai/src/main/java/org/springframework/ai/openai/OpenAiImageModel.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -39,7 +39,6 @@ import org.springframework.ai.openai.metadata.OpenAiImageGenerationMetadata; import org.springframework.ai.retry.RetryUtils; import org.springframework.http.ResponseEntity; -import org.springframework.lang.Nullable; import org.springframework.retry.support.RetryTemplate; import org.springframework.util.Assert; @@ -127,13 +126,16 @@ public OpenAiImageModel(OpenAiImageApi openAiImageApi, OpenAiImageOptions option @Override public ImageResponse call(ImagePrompt imagePrompt) { - OpenAiImageOptions requestImageOptions = mergeOptions(imagePrompt.getOptions(), this.defaultOptions); - OpenAiImageApi.OpenAiImageRequest imageRequest = createRequest(imagePrompt, requestImageOptions); + // Before moving any further, build the final request ImagePrompt, + // merging runtime and default options. + ImagePrompt requestImagePrompt = buildRequestImagePrompt(imagePrompt); + + OpenAiImageApi.OpenAiImageRequest imageRequest = createRequest(requestImagePrompt); var observationContext = ImageModelObservationContext.builder() .imagePrompt(imagePrompt) .provider(OpenAiApiConstants.PROVIDER_NAME) - .requestOptions(requestImageOptions) + .requestOptions(requestImagePrompt.getOptions()) .build(); return ImageModelObservationDocumentation.IMAGE_MODEL_OPERATION @@ -151,14 +153,14 @@ public ImageResponse call(ImagePrompt imagePrompt) { }); } - private OpenAiImageApi.OpenAiImageRequest createRequest(ImagePrompt imagePrompt, - OpenAiImageOptions requestImageOptions) { + private OpenAiImageApi.OpenAiImageRequest createRequest(ImagePrompt imagePrompt) { String instructions = imagePrompt.getInstructions().get(0).getText(); + OpenAiImageOptions imageOptions = (OpenAiImageOptions) imagePrompt.getOptions(); OpenAiImageApi.OpenAiImageRequest imageRequest = new OpenAiImageApi.OpenAiImageRequest(instructions, OpenAiImageApi.DEFAULT_IMAGE_MODEL); - return ModelOptionsUtils.merge(requestImageOptions, imageRequest, OpenAiImageApi.OpenAiImageRequest.class); + return ModelOptionsUtils.merge(imageOptions, imageRequest, OpenAiImageApi.OpenAiImageRequest.class); } private ImageResponse convertResponse(ResponseEntity imageResponseEntity, @@ -179,31 +181,29 @@ private ImageResponse convertResponse(ResponseEntity> { private final List inputs; + @Nullable private final EmbeddingOptions options; - public EmbeddingRequest(List inputs, EmbeddingOptions options) { + public EmbeddingRequest(List inputs, @Nullable EmbeddingOptions options) { this.inputs = inputs; this.options = options; } @@ -42,6 +44,7 @@ public List getInstructions() { } @Override + @Nullable public EmbeddingOptions getOptions() { return this.options; }