Skip to content

Commit ebd29e0

Browse files
ilayaperumalgtzolov
authored andcommitted
GH-1826 Fix EmbeddingModel's usage on Document#embedding
- Since the Document object's reference to the `embedding` is deprecated and will be removed, the VectorStore implementations require a way to store the embedding of the corresponding Document objects - One way to fix this is, to have the EmbeddingModel#embed to return the embeddings in the same order as that of the Documents passed to it. - Since both the Document and embedding collections use the List object, their iteration operation will make sure to keep them in line with the same order. - A fix is required to preserve the order when batching strategy is applied. - Updated the Javadoc for BatchingStrategy - Fixed the Document List order in TokenCountBatchingStrategy - Refactored the vector store implementations to update this change Resolves #GH-1826
1 parent 6cfe5e7 commit ebd29e0

File tree

24 files changed

+118
-94
lines changed

24 files changed

+118
-94
lines changed

models/spring-ai-openai/src/test/java/org/springframework/ai/openai/embedding/EmbeddingIT.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,12 +66,12 @@ void defaultEmbedding() {
6666
@Test
6767
void embeddingBatchDocuments() throws Exception {
6868
assertThat(this.embeddingModel).isNotNull();
69-
List<float[]> embedded = this.embeddingModel.embed(
69+
List<float[]> embeddings = this.embeddingModel.embed(
7070
List.of(new Document("Hello world"), new Document("Hello Spring"), new Document("Hello Spring AI!")),
7171
OpenAiEmbeddingOptions.builder().withModel(OpenAiApi.DEFAULT_EMBEDDING_MODEL).build(),
7272
new TokenCountBatchingStrategy());
73-
assertThat(embedded.size()).isEqualTo(3);
74-
embedded.forEach(embedding -> assertThat(embedding.length).isEqualTo(this.embeddingModel.dimensions()));
73+
assertThat(embeddings.size()).isEqualTo(3);
74+
embeddings.forEach(embedding -> assertThat(embedding.length).isEqualTo(this.embeddingModel.dimensions()));
7575
}
7676

7777
@Test

spring-ai-core/src/main/java/org/springframework/ai/embedding/BatchingStrategy.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@ public interface BatchingStrategy {
3131

3232
/**
3333
* {@link EmbeddingModel} implementations can call this method to optimize embedding
34-
* tokens. The incoming collection of {@link Document}s are split into su-batches.
34+
* tokens. The incoming collection of {@link Document}s are split into sub-batches. It
35+
* is important to preserve the order of the list of {@link Document}s when batching
36+
* as they are mapped to their corresponding embeddings by their order.
3537
* @param documents to batch
3638
* @return a list of sub-batches that contain {@link Document}s.
3739
*/

spring-ai-core/src/main/java/org/springframework/ai/embedding/EmbeddingModel.java

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -78,25 +78,23 @@ default List<float[]> embed(List<String> texts) {
7878
* @param options {@link EmbeddingOptions}.
7979
* @param batchingStrategy {@link BatchingStrategy}.
8080
* @return a list of float[] that represents the vectors for the incoming
81-
* {@link Document}s.
81+
* {@link Document}s. The returned list is expected to be in the same order of the
82+
* {@link Document} list.
8283
*/
8384
default List<float[]> embed(List<Document> documents, EmbeddingOptions options, BatchingStrategy batchingStrategy) {
8485
Assert.notNull(documents, "Documents must not be null");
85-
List<float[]> embeddings = new ArrayList<>();
86-
86+
List<float[]> embeddings = new ArrayList<>(documents.size());
8787
List<List<Document>> batch = batchingStrategy.batch(documents);
88-
8988
for (List<Document> subBatch : batch) {
9089
List<String> texts = subBatch.stream().map(Document::getContent).toList();
9190
EmbeddingRequest request = new EmbeddingRequest(texts, options);
9291
EmbeddingResponse response = this.call(request);
9392
for (int i = 0; i < subBatch.size(); i++) {
94-
Document document = subBatch.get(i);
95-
float[] output = response.getResults().get(i).getOutput();
96-
embeddings.add(output);
97-
document.setEmbedding(output);
93+
embeddings.add(response.getResults().get(i).getOutput());
9894
}
9995
}
96+
Assert.isTrue(embeddings.size() == documents.size(),
97+
"Embeddings must have the same number as that of the documents");
10098
return embeddings;
10199
}
102100

spring-ai-core/src/main/java/org/springframework/ai/embedding/TokenCountBatchingStrategy.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
package org.springframework.ai.embedding;
1818

1919
import java.util.ArrayList;
20-
import java.util.HashMap;
20+
import java.util.LinkedHashMap;
2121
import java.util.List;
2222
import java.util.Map;
2323

@@ -139,7 +139,9 @@ public List<List<Document>> batch(List<Document> documents) {
139139
List<List<Document>> batches = new ArrayList<>();
140140
int currentSize = 0;
141141
List<Document> currentBatch = new ArrayList<>();
142-
Map<Document, Integer> documentTokens = new HashMap<>();
142+
// Make sure the documentTokens' entry order is preserved by making it a
143+
// LinkedHashMap.
144+
Map<Document, Integer> documentTokens = new LinkedHashMap<>();
143145

144146
for (Document document : documents) {
145147
int tokenCount = this.tokenCountEstimator

vector-stores/spring-ai-azure-cosmos-db-store/src/main/java/org/springframework/ai/vectorstore/CosmosDBVectorStore.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -204,13 +204,14 @@ private JsonNode mapCosmosDocument(Document document, float[] queryEmbedding) {
204204
public void doAdd(List<Document> documents) {
205205

206206
// Batch the documents based on the batching strategy
207-
this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
207+
List<float[]> embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(),
208+
this.batchingStrategy);
208209

209210
// Create a list to hold both the CosmosItemOperation and the corresponding
210211
// document ID
211212
List<ImmutablePair<String, CosmosItemOperation>> itemOperationsWithIds = documents.stream().map(doc -> {
212-
CosmosItemOperation operation = CosmosBulkOperations
213-
.getCreateItemOperation(mapCosmosDocument(doc, doc.getEmbedding()), new PartitionKey(doc.getId()));
213+
CosmosItemOperation operation = CosmosBulkOperations.getCreateItemOperation(
214+
mapCosmosDocument(doc, embeddings.get(documents.indexOf(doc))), new PartitionKey(doc.getId()));
214215
return new ImmutablePair<>(doc.getId(), operation); // Pair the document ID
215216
// with the operation
216217
}).toList();

vector-stores/spring-ai-azure-store/src/main/java/org/springframework/ai/vectorstore/azure/AzureVectorStore.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -223,12 +223,13 @@ public void doAdd(List<Document> documents) {
223223
return; // nothing to do;
224224
}
225225

226-
this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
226+
List<float[]> embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(),
227+
this.batchingStrategy);
227228

228229
final var searchDocuments = documents.stream().map(document -> {
229230
SearchDocument searchDocument = new SearchDocument();
230231
searchDocument.put(ID_FIELD_NAME, document.getId());
231-
searchDocument.put(EMBEDDING_FIELD_NAME, document.getEmbedding());
232+
searchDocument.put(EMBEDDING_FIELD_NAME, embeddings.get(documents.indexOf(document)));
232233
searchDocument.put(CONTENT_FIELD_NAME, document.getContent());
233234
searchDocument.put(METADATA_FIELD_NAME, new JSONObject(document.getMetadata()).toJSONString());
234235

@@ -327,7 +328,6 @@ public List<Document> doSimilaritySearch(SearchRequest request) {
327328
.content(entry.content)
328329
.metadata(metadata)
329330
.score(result.getScore())
330-
.embedding(EmbeddingUtils.toPrimitive(entry.embedding))
331331
.build();
332332
})
333333
.collect(Collectors.toList());

vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStore.java

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,8 @@ private static Float[] toFloatArray(float[] embedding) {
181181
public void doAdd(List<Document> documents) {
182182
var futures = new CompletableFuture[documents.size()];
183183

184-
this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
184+
List<float[]> embeddings = this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(),
185+
this.batchingStrategy);
185186

186187
int i = 0;
187188
for (Document d : documents) {
@@ -196,7 +197,8 @@ public void doAdd(List<Document> documents) {
196197

197198
builder = builder.setString(this.conf.schema.content(), d.getContent())
198199
.setVector(this.conf.schema.embedding(),
199-
CqlVector.newInstance(EmbeddingUtils.toList(d.getEmbedding())), Float.class);
200+
CqlVector.newInstance(EmbeddingUtils.toList(embeddings.get(documents.indexOf(d)))),
201+
Float.class);
200202

201203
for (var metadataColumn : this.conf.schema.metadataColumns()
202204
.stream()
@@ -265,10 +267,6 @@ public List<Document> doSimilaritySearch(SearchRequest request) {
265267
.score((double) score)
266268
.build();
267269

268-
if (this.conf.returnEmbeddings) {
269-
doc.setEmbedding(EmbeddingUtils
270-
.toPrimitive(row.getVector(this.conf.schema.embedding(), Float.class).stream().toList()));
271-
}
272270
documents.add(doc);
273271
}
274272
return documents;

vector-stores/spring-ai-cassandra-store/src/main/java/org/springframework/ai/vectorstore/CassandraVectorStoreConfig.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,8 @@ public final class CassandraVectorStoreConfig implements AutoCloseable {
9090

9191
final boolean disallowSchemaChanges;
9292

93+
// TODO: Remove this flag as the document no longer holds embeddings.
94+
@Deprecated(since = "1.0.0-M5", forRemoval = true)
9395
final boolean returnEmbeddings;
9496

9597
final DocumentIdTranslator documentIdTranslator;

vector-stores/spring-ai-cassandra-store/src/test/java/org/springframework/ai/vectorstore/CassandraVectorStoreIT.java

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -122,18 +122,12 @@ void addAndSearch() {
122122

123123
List<Document> documents = documents();
124124
store.add(documents);
125-
for (Document d : documents) {
126-
assertThat(d.getEmbedding()).satisfiesAnyOf(e -> assertThat(e).isNotNull(),
127-
e -> assertThat(e).isNotEmpty());
128-
}
129125

130126
List<Document> results = store.similaritySearch(SearchRequest.query("Spring").withTopK(1));
131127

132128
assertThat(results).hasSize(1);
133129
Document resultDoc = results.get(0);
134130
assertThat(resultDoc.getId()).isEqualTo(documents().get(0).getId());
135-
assertThat(resultDoc.getEmbedding()).satisfiesAnyOf(e -> assertThat(e).isNull(),
136-
e -> assertThat(e).isEmpty());
137131

138132
assertThat(resultDoc.getContent()).contains(
139133
"Spring AI provides abstractions that serve as the foundation for developing AI applications.");
@@ -159,17 +153,12 @@ void addAndSearchReturnEmbeddings() {
159153
try (CassandraVectorStore store = createTestStore(context, builder)) {
160154
List<Document> documents = documents();
161155
store.add(documents);
162-
for (Document d : documents) {
163-
assertThat(d.getEmbedding()).satisfiesAnyOf(e -> assertThat(e).isNotNull(),
164-
e -> assertThat(e).isNotEmpty());
165-
}
166156

167157
List<Document> results = store.similaritySearch(SearchRequest.query("Spring").withTopK(1));
168158

169159
assertThat(results).hasSize(1);
170160
Document resultDoc = results.get(0);
171161
assertThat(resultDoc.getId()).isEqualTo(documents().get(0).getId());
172-
assertThat(resultDoc.getEmbedding()).isNotEmpty();
173162

174163
assertThat(resultDoc.getContent()).contains(
175164
"Spring AI provides abstractions that serve as the foundation for developing AI applications.");

vector-stores/spring-ai-chroma-store/src/main/java/org/springframework/ai/vectorstore/ChromaVectorStore.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -145,14 +145,14 @@ public void doAdd(@NonNull List<Document> documents) {
145145
List<String> contents = new ArrayList<>();
146146
List<float[]> embeddings = new ArrayList<>();
147147

148-
this.embeddingModel.embed(documents, EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
148+
List<float[]> documentEmbeddings = this.embeddingModel.embed(documents,
149+
EmbeddingOptionsBuilder.builder().build(), this.batchingStrategy);
149150

150151
for (Document document : documents) {
151152
ids.add(document.getId());
152153
metadatas.add(document.getMetadata());
153154
contents.add(document.getContent());
154-
document.setEmbedding(document.getEmbedding());
155-
embeddings.add(document.getEmbedding());
155+
embeddings.add(documentEmbeddings.get(documents.indexOf(document)));
156156
}
157157

158158
this.chromaApi.upsertEmbeddings(this.collectionId,
@@ -192,12 +192,12 @@ public Optional<Boolean> doDelete(@NonNull List<String> idList) {
192192
if (metadata == null) {
193193
metadata = new HashMap<>();
194194
}
195+
195196
metadata.put(DocumentMetadata.DISTANCE.value(), distance);
196197
Document document = Document.builder()
197198
.id(id)
198199
.content(content)
199200
.metadata(metadata)
200-
.embedding(chromaEmbedding.embedding())
201201
.score(1.0 - distance)
202202
.build();
203203
responseDocuments.add(document);

0 commit comments

Comments
 (0)