Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -42,7 +42,6 @@
import org.springframework.ai.openai.api.OpenAiApi.EmbeddingList;
import org.springframework.ai.openai.api.common.OpenAiApiConstants;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.lang.Nullable;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;

Expand Down Expand Up @@ -148,13 +147,16 @@ public float[] embed(Document document) {

@Override
public EmbeddingResponse call(EmbeddingRequest request) {
OpenAiEmbeddingOptions requestOptions = mergeOptions(request.getOptions(), this.defaultOptions);
OpenAiApi.EmbeddingRequest<List<String>> apiRequest = createRequest(request, requestOptions);
// Before moving any further, build the final request EmbeddingRequest,
// merging runtime and default options.
EmbeddingRequest embeddingRequest = buildEmbeddingRequest(request);

OpenAiApi.EmbeddingRequest<List<String>> apiRequest = createRequest(embeddingRequest);

var observationContext = EmbeddingModelObservationContext.builder()
.embeddingRequest(request)
.provider(OpenAiApiConstants.PROVIDER_NAME)
.requestOptions(requestOptions)
.requestOptions(embeddingRequest.getOptions())
.build();

return EmbeddingModelObservationDocumentation.EMBEDDING_MODEL_OPERATION
Expand Down Expand Up @@ -190,35 +192,32 @@ private DefaultUsage getDefaultUsage(OpenAiApi.Usage usage) {
return new DefaultUsage(usage.promptTokens(), usage.completionTokens(), usage.totalTokens(), usage);
}

private OpenAiApi.EmbeddingRequest<List<String>> createRequest(EmbeddingRequest request,
OpenAiEmbeddingOptions requestOptions) {
private OpenAiApi.EmbeddingRequest<List<String>> createRequest(EmbeddingRequest request) {
OpenAiEmbeddingOptions requestOptions = (OpenAiEmbeddingOptions) request.getOptions();
return new OpenAiApi.EmbeddingRequest<>(request.getInstructions(), requestOptions.getModel(),
requestOptions.getEncodingFormat(), requestOptions.getDimensions(), requestOptions.getUser());
}

/**
* Merge runtime and default {@link EmbeddingOptions} to compute the final options to
* use in the request.
*/
private OpenAiEmbeddingOptions mergeOptions(@Nullable EmbeddingOptions runtimeOptions,
OpenAiEmbeddingOptions defaultOptions) {
var runtimeOptionsForProvider = ModelOptionsUtils.copyToTarget(runtimeOptions, EmbeddingOptions.class,
OpenAiEmbeddingOptions.class);

if (runtimeOptionsForProvider == null) {
return defaultOptions;
private EmbeddingRequest buildEmbeddingRequest(EmbeddingRequest embeddingRequest) {
// Process runtime options
OpenAiEmbeddingOptions runtimeOptions = null;
if (embeddingRequest.getOptions() != null) {
runtimeOptions = ModelOptionsUtils.copyToTarget(embeddingRequest.getOptions(), EmbeddingOptions.class,
OpenAiEmbeddingOptions.class);
}

return OpenAiEmbeddingOptions.builder()
OpenAiEmbeddingOptions requestOptions = runtimeOptions == null ? this.defaultOptions : OpenAiEmbeddingOptions
.builder()
// Handle portable embedding options
.model(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getModel(), defaultOptions.getModel()))
.dimensions(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getDimensions(),
defaultOptions.getDimensions()))
.model(ModelOptionsUtils.mergeOption(runtimeOptions.getModel(), this.defaultOptions.getModel()))
.dimensions(ModelOptionsUtils.mergeOption(runtimeOptions.getDimensions(), defaultOptions.getDimensions()))
// Handle OpenAI specific embedding options
.encodingFormat(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getEncodingFormat(),
.encodingFormat(ModelOptionsUtils.mergeOption(runtimeOptions.getEncodingFormat(),
defaultOptions.getEncodingFormat()))
.user(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getUser(), defaultOptions.getUser()))
.user(ModelOptionsUtils.mergeOption(runtimeOptions.getUser(), this.defaultOptions.getUser()))
.build();

return new EmbeddingRequest(embeddingRequest.getInstructions(), requestOptions);
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2023-2024 the original author or authors.
* Copyright 2023-2025 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -39,7 +39,6 @@
import org.springframework.ai.openai.metadata.OpenAiImageGenerationMetadata;
import org.springframework.ai.retry.RetryUtils;
import org.springframework.http.ResponseEntity;
import org.springframework.lang.Nullable;
import org.springframework.retry.support.RetryTemplate;
import org.springframework.util.Assert;

Expand Down Expand Up @@ -127,13 +126,16 @@ public OpenAiImageModel(OpenAiImageApi openAiImageApi, OpenAiImageOptions option

@Override
public ImageResponse call(ImagePrompt imagePrompt) {
OpenAiImageOptions requestImageOptions = mergeOptions(imagePrompt.getOptions(), this.defaultOptions);
OpenAiImageApi.OpenAiImageRequest imageRequest = createRequest(imagePrompt, requestImageOptions);
// Before moving any further, build the final request ImagePrompt,
// merging runtime and default options.
ImagePrompt requestImagePrompt = buildRequestImagePrompt(imagePrompt);

OpenAiImageApi.OpenAiImageRequest imageRequest = createRequest(requestImagePrompt);

var observationContext = ImageModelObservationContext.builder()
.imagePrompt(imagePrompt)
.provider(OpenAiApiConstants.PROVIDER_NAME)
.requestOptions(requestImageOptions)
.requestOptions(requestImagePrompt.getOptions())
.build();

return ImageModelObservationDocumentation.IMAGE_MODEL_OPERATION
Expand All @@ -151,14 +153,14 @@ public ImageResponse call(ImagePrompt imagePrompt) {
});
}

private OpenAiImageApi.OpenAiImageRequest createRequest(ImagePrompt imagePrompt,
OpenAiImageOptions requestImageOptions) {
private OpenAiImageApi.OpenAiImageRequest createRequest(ImagePrompt imagePrompt) {
String instructions = imagePrompt.getInstructions().get(0).getText();
OpenAiImageOptions imageOptions = (OpenAiImageOptions) imagePrompt.getOptions();

OpenAiImageApi.OpenAiImageRequest imageRequest = new OpenAiImageApi.OpenAiImageRequest(instructions,
OpenAiImageApi.DEFAULT_IMAGE_MODEL);

return ModelOptionsUtils.merge(requestImageOptions, imageRequest, OpenAiImageApi.OpenAiImageRequest.class);
return ModelOptionsUtils.merge(imageOptions, imageRequest, OpenAiImageApi.OpenAiImageRequest.class);
}

private ImageResponse convertResponse(ResponseEntity<OpenAiImageApi.OpenAiImageResponse> imageResponseEntity,
Expand All @@ -179,31 +181,29 @@ private ImageResponse convertResponse(ResponseEntity<OpenAiImageApi.OpenAiImageR
return new ImageResponse(imageGenerationList, openAiImageResponseMetadata);
}

/**
* Merge runtime and default {@link ImageOptions} to compute the final options to use
* in the request.
*/
private OpenAiImageOptions mergeOptions(@Nullable ImageOptions runtimeOptions, OpenAiImageOptions defaultOptions) {
var runtimeOptionsForProvider = ModelOptionsUtils.copyToTarget(runtimeOptions, ImageOptions.class,
OpenAiImageOptions.class);

if (runtimeOptionsForProvider == null) {
return defaultOptions;
private ImagePrompt buildRequestImagePrompt(ImagePrompt imagePrompt) {
// Process runtime options
OpenAiImageOptions runtimeOptions = null;
if (imagePrompt.getOptions() != null) {
runtimeOptions = ModelOptionsUtils.copyToTarget(imagePrompt.getOptions(), ImageOptions.class,
OpenAiImageOptions.class);
}

return OpenAiImageOptions.builder()
OpenAiImageOptions requestOptions = runtimeOptions == null ? this.defaultOptions : OpenAiImageOptions.builder()
// Handle portable image options
.model(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getModel(), defaultOptions.getModel()))
.N(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getN(), defaultOptions.getN()))
.responseFormat(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getResponseFormat(),
.model(ModelOptionsUtils.mergeOption(runtimeOptions.getModel(), defaultOptions.getModel()))
.N(ModelOptionsUtils.mergeOption(runtimeOptions.getN(), defaultOptions.getN()))
.responseFormat(ModelOptionsUtils.mergeOption(runtimeOptions.getResponseFormat(),
defaultOptions.getResponseFormat()))
.width(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getWidth(), defaultOptions.getWidth()))
.height(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getHeight(), defaultOptions.getHeight()))
.style(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getStyle(), defaultOptions.getStyle()))
.width(ModelOptionsUtils.mergeOption(runtimeOptions.getWidth(), defaultOptions.getWidth()))
.height(ModelOptionsUtils.mergeOption(runtimeOptions.getHeight(), defaultOptions.getHeight()))
.style(ModelOptionsUtils.mergeOption(runtimeOptions.getStyle(), defaultOptions.getStyle()))
// Handle OpenAI specific image options
.quality(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getQuality(), defaultOptions.getQuality()))
.user(ModelOptionsUtils.mergeOption(runtimeOptionsForProvider.getUser(), defaultOptions.getUser()))
.quality(ModelOptionsUtils.mergeOption(runtimeOptions.getQuality(), defaultOptions.getQuality()))
.user(ModelOptionsUtils.mergeOption(runtimeOptions.getUser(), defaultOptions.getUser()))
.build();

return new ImagePrompt(imagePrompt.getInstructions(), requestOptions);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import java.util.List;

import org.springframework.ai.model.ModelRequest;
import org.springframework.lang.Nullable;

/**
* Request to embed a list of input instructions.
Expand All @@ -29,9 +30,10 @@ public class EmbeddingRequest implements ModelRequest<List<String>> {

private final List<String> inputs;

@Nullable
private final EmbeddingOptions options;

public EmbeddingRequest(List<String> inputs, EmbeddingOptions options) {
public EmbeddingRequest(List<String> inputs, @Nullable EmbeddingOptions options) {
this.inputs = inputs;
this.options = options;
}
Expand All @@ -42,6 +44,7 @@ public List<String> getInstructions() {
}

@Override
@Nullable
public EmbeddingOptions getOptions() {
return this.options;
}
Expand Down