diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingUtils.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingUtils.java index 16bf6a24bf7..c2aa0e8df86 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingUtils.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/VertexAiEmbeddingUtils.java @@ -140,7 +140,7 @@ public Struct build() { Struct.Builder textBuilder = Struct.newBuilder(); textBuilder.putFields("content", valueOf(this.content)); if (StringUtils.hasText(this.taskType)) { - textBuilder.putFields("taskType", valueOf(this.taskType)); + textBuilder.putFields("task_type", valueOf(this.taskType)); } if (StringUtils.hasText(this.title)) { textBuilder.putFields("title", valueOf(this.title)); diff --git a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingOptions.java b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingOptions.java index 42d16407710..4f384b35b26 100644 --- a/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingOptions.java +++ b/models/spring-ai-vertex-ai-embedding/src/main/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingOptions.java @@ -187,6 +187,9 @@ public Builder from(VertexAiTextEmbeddingOptions fromOptions) { if (fromOptions.getTaskType() != null) { this.options.setTaskType(fromOptions.getTaskType()); } + if (fromOptions.getAutoTruncate() != null) { + this.options.setAutoTruncate(fromOptions.getAutoTruncate()); + } if (StringUtils.hasText(fromOptions.getTitle())) { this.options.setTitle(fromOptions.getTitle()); } diff --git a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelIT.java b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelIT.java index ddfb9bfa106..40b7ffd881b 100644 --- a/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelIT.java +++ b/models/spring-ai-vertex-ai-embedding/src/test/java/org/springframework/ai/vertexai/embedding/text/VertexAiTextEmbeddingModelIT.java @@ -1,5 +1,5 @@ /* - * Copyright 2023-2024 the original author or authors. + * Copyright 2023-2025 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,6 +18,13 @@ import java.util.List; +import com.google.cloud.aiplatform.v1.EndpointName; +import com.google.cloud.aiplatform.v1.PredictRequest; +import com.google.cloud.aiplatform.v1.PredictionServiceClient; +import com.google.cloud.aiplatform.v1.PredictionServiceSettings; +import com.google.protobuf.Struct; +import com.google.protobuf.Value; +import org.junit.jupiter.api.Test; import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.ValueSource; @@ -30,6 +37,7 @@ import org.springframework.boot.test.context.SpringBootTest; import org.springframework.context.annotation.Bean; +import static java.util.stream.Collectors.toList; import static org.assertj.core.api.Assertions.assertThat; @SpringBootTest(classes = VertexAiTextEmbeddingModelIT.Config.class) @@ -65,6 +73,116 @@ void defaultEmbedding(String modelName) { assertThat(this.embeddingModel.dimensions()).isEqualTo(768); } + // Fixing https://github.com/spring-projects/spring-ai/issues/2168 + @Test + void testTaskTypeProperty() { + // Use text-embedding-005 model + VertexAiTextEmbeddingOptions options = VertexAiTextEmbeddingOptions.builder() + .model("text-embedding-005") + .taskType(VertexAiTextEmbeddingOptions.TaskType.RETRIEVAL_DOCUMENT) + .build(); + + String text = "Test text for embedding"; + + // Generate embedding using Spring AI with RETRIEVAL_DOCUMENT task type + EmbeddingResponse embeddingResponse = this.embeddingModel.call(new EmbeddingRequest(List.of(text), options)); + + assertThat(embeddingResponse.getResults()).hasSize(1); + assertThat(embeddingResponse.getResults().get(0).getOutput()).isNotNull(); + + // Get the embedding result + float[] springAiEmbedding = embeddingResponse.getResults().get(0).getOutput(); + + // Now generate the same embedding using Google SDK directly with + // RETRIEVAL_DOCUMENT + float[] googleSdkDocumentEmbedding = getEmbeddingUsingGoogleSdk(text, "RETRIEVAL_DOCUMENT"); + + // Also generate embedding using Google SDK with RETRIEVAL_QUERY (which is the + // default) + float[] googleSdkQueryEmbedding = getEmbeddingUsingGoogleSdk(text, "RETRIEVAL_QUERY"); + + // Spring AI embedding should match with what gets generated by Google SDK with + // RETRIEVAL_DOCUMENT task type. + assertThat(springAiEmbedding) + .as("Spring AI embedding with RETRIEVAL_DOCUMENT should match Google SDK RETRIEVAL_DOCUMENT embedding") + .isEqualTo(googleSdkDocumentEmbedding); + + // Spring AI embedding which uses RETRIEVAL_DOCUMENT task_type should not match + // with what gets generated by + // Google SDK with RETRIEVAL_QUERY task type. + assertThat(springAiEmbedding) + .as("Spring AI embedding with RETRIEVAL_DOCUMENT should NOT match Google SDK RETRIEVAL_QUERY embedding") + .isNotEqualTo(googleSdkQueryEmbedding); + } + + // Fixing https://github.com/spring-projects/spring-ai/issues/2168 + @Test + void testDefaultTaskTypeBehavior() { + // Test default behavior without explicitly setting task type + VertexAiTextEmbeddingOptions options = VertexAiTextEmbeddingOptions.builder() + .model("text-embedding-005") + .build(); + + String text = "Test text for default embedding"; + + EmbeddingResponse embeddingResponse = this.embeddingModel.call(new EmbeddingRequest(List.of(text), options)); + + assertThat(embeddingResponse.getResults()).hasSize(1); + + float[] springAiDefaultEmbedding = embeddingResponse.getResults().get(0).getOutput(); + + // According to documentation, default should be RETRIEVAL_DOCUMENT + float[] googleSdkDocumentEmbedding = getEmbeddingUsingGoogleSdk(text, "RETRIEVAL_DOCUMENT"); + + assertThat(springAiDefaultEmbedding) + .as("Default Spring AI embedding should match Google SDK RETRIEVAL_DOCUMENT embedding") + .isEqualTo(googleSdkDocumentEmbedding); + } + + private float[] getEmbeddingUsingGoogleSdk(String text, String taskType) { + try { + String endpoint = String.format("%s-aiplatform.googleapis.com:443", + System.getenv("VERTEX_AI_GEMINI_LOCATION")); + String project = System.getenv("VERTEX_AI_GEMINI_PROJECT_ID"); + + PredictionServiceSettings settings = PredictionServiceSettings.newBuilder().setEndpoint(endpoint).build(); + + EndpointName endpointName = EndpointName.ofProjectLocationPublisherModelName(project, + System.getenv("VERTEX_AI_GEMINI_LOCATION"), "google", "text-embedding-005"); + + try (PredictionServiceClient client = PredictionServiceClient.create(settings)) { + PredictRequest.Builder request = PredictRequest.newBuilder().setEndpoint(endpointName.toString()); + + request.addInstances(Value.newBuilder() + .setStructValue(Struct.newBuilder() + .putFields("content", Value.newBuilder().setStringValue(text).build()) + .putFields("task_type", Value.newBuilder().setStringValue(taskType).build()) + .build()) + .build()); + + var prediction = client.predict(request.build()).getPredictionsList().get(0); + Value embeddings = prediction.getStructValue().getFieldsOrThrow("embeddings"); + Value values = embeddings.getStructValue().getFieldsOrThrow("values"); + + List floatList = values.getListValue() + .getValuesList() + .stream() + .map(Value::getNumberValue) + .map(Double::floatValue) + .collect(toList()); + + float[] floatArray = new float[floatList.size()]; + for (int i = 0; i < floatList.size(); i++) { + floatArray[i] = floatList.get(i); + } + return floatArray; + } + } + catch (Exception e) { + throw new RuntimeException("Failed to get embedding from Google SDK", e); + } + } + @SpringBootConfiguration static class Config {