From 25d6169f6d8e3a07cff972e60b4440fdceffb708 Mon Sep 17 00:00:00 2001 From: Ilayaperumal Gopinathan Date: Mon, 25 Nov 2024 15:00:52 +0000 Subject: [PATCH] Refactor SimpleVectorStore - Remove SimpleVectorStore's dependency on deprecated embeddings from Document object - Create a custom Content object that represents the SimpleVectorStore's contents and embedding --- .../ai/vectorstore/SimpleVectorStore.java | 37 ++-- .../vectorstore/SimpleVectorStoreContent.java | 159 ++++++++++++++++++ .../SimpleVectorStoreSimilarityTests.java | 46 +++++ 3 files changed, 228 insertions(+), 14 deletions(-) create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStoreContent.java create mode 100644 spring-ai-core/src/test/java/org/springframework/ai/vectorstore/SimpleVectorStoreSimilarityTests.java diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java index f2e558eb6a8..16755447f20 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStore.java @@ -67,6 +67,7 @@ * @author Mark Pollack * @author Christian Tzolov * @author Sebastien Deleuze + * @author Ilayaperumal Gopinathan */ public class SimpleVectorStore extends AbstractObservationVectorStore { @@ -74,7 +75,7 @@ public class SimpleVectorStore extends AbstractObservationVectorStore { private final ObjectMapper objectMapper; - protected Map store = new ConcurrentHashMap<>(); + protected Map store = new ConcurrentHashMap<>(); protected EmbeddingModel embeddingModel; @@ -97,8 +98,10 @@ public void doAdd(List documents) { for (Document document : documents) { logger.info("Calling EmbeddingModel for document id = {}", document.getId()); float[] embedding = this.embeddingModel.embed(document); - document.setEmbedding(embedding); - 
this.store.put(document.getId(), document); + SimpleVectorStoreContent storeContent = new SimpleVectorStoreContent(document.getId(), + document.getContent(), document.getMetadata()); + storeContent.setEmbedding(embedding); + this.store.put(document.getId(), storeContent); } } @@ -120,12 +123,12 @@ public List doSimilaritySearch(SearchRequest request) { float[] userQueryEmbedding = getUserQueryEmbedding(request.getQuery()); return this.store.values() .stream() - .map(entry -> new Similarity(entry.getId(), + .map(entry -> new Similarity(entry, EmbeddingMath.cosineSimilarity(userQueryEmbedding, entry.getEmbedding()))) .filter(s -> s.score >= request.getSimilarityThreshold()) .sorted(Comparator.comparingDouble(s -> s.score).reversed()) .limit(request.getTopK()) - .map(s -> this.store.get(s.key)) + .map(s -> s.getDocument()) .toList(); } @@ -176,12 +179,11 @@ public void save(File file) { * @param file the file to load the vector store content */ public void load(File file) { - TypeReference> typeRef = new TypeReference<>() { + TypeReference> typeRef = new TypeReference<>() { }; try { - Map deserializedMap = this.objectMapper.readValue(file, typeRef); - this.store = deserializedMap; + this.store = this.objectMapper.readValue(file, typeRef); } catch (IOException ex) { throw new RuntimeException(ex); @@ -193,12 +195,11 @@ public void load(File file) { * @param resource the resource to load the vector store content */ public void load(Resource resource) { - TypeReference> typeRef = new TypeReference<>() { + TypeReference> typeRef = new TypeReference<>() { }; try { - Map deserializedMap = this.objectMapper.readValue(resource.getInputStream(), typeRef); - this.store = deserializedMap; + this.store = this.objectMapper.readValue(resource.getInputStream(), typeRef); } catch (IOException ex) { throw new RuntimeException(ex); @@ -232,15 +233,23 @@ public VectorStoreObservationContext.Builder createObservationContextBuilder(Str public static class Similarity { - private String 
key; + private SimpleVectorStoreContent content; private double score; - public Similarity(String key, double score) { - this.key = key; + public Similarity(SimpleVectorStoreContent content, double score) { + this.content = content; this.score = score; } + Document getDocument() { + return Document.builder() + .withId(this.content.getId()) + .withContent(this.content.getContent()) + .withMetadata(this.content.getMetadata()) + .build(); + } + } public final class EmbeddingMath { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStoreContent.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStoreContent.java new file mode 100644 index 00000000000..dc4f8eb1dad --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/SimpleVectorStoreContent.java @@ -0,0 +1,159 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.springframework.ai.vectorstore; + +import java.util.HashMap; +import java.util.Map; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import org.springframework.ai.document.id.IdGenerator; +import org.springframework.ai.document.id.RandomIdGenerator; +import org.springframework.ai.model.Content; +import org.springframework.util.Assert; + +/** + * A simple {@link Content} object which represents the content and metadata along + * with its embeddings. + */ +public class SimpleVectorStoreContent implements Content { + + /** + * Unique ID + */ + private final String id; + + /** + * Document content. + */ + private final String content; + + /** + * Metadata for the document. It should not be nested and values should be restricted + * to string, int, float, boolean for simple use with Vector Dbs. + */ + private Map metadata; + + /** + * Embedding of the document. Note: ephemeral field. + */ + @JsonProperty(index = 100) + private float[] embedding = new float[0]; + + @JsonCreator(mode = JsonCreator.Mode.PROPERTIES) + public SimpleVectorStoreContent(@JsonProperty("content") String content) { + this(content, new HashMap<>()); + } + + public SimpleVectorStoreContent(String content, Map metadata) { + this(content, metadata, new RandomIdGenerator()); + } + + public SimpleVectorStoreContent(String content, Map metadata, IdGenerator idGenerator) { + this(idGenerator.generateId(content, metadata), content, metadata); + } + + public SimpleVectorStoreContent(String id, String content, Map metadata) { + Assert.hasText(id, "id must not be null or empty"); + Assert.notNull(content, "content must not be null"); + Assert.notNull(metadata, "metadata must not be null"); + + this.id = id; + this.content = content; + this.metadata = metadata; + } + + public String getId() { + return this.id; + } + + @Override + public String getContent() { + return this.content; + } + + @Override + public Map getMetadata() { + 
return this.metadata; + } + + public float[] getEmbedding() { + return this.embedding; + } + + public void setEmbedding(float[] embedding) { + Assert.notNull(embedding, "embedding must not be null"); + this.embedding = embedding; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((this.id == null) ? 0 : this.id.hashCode()); + result = prime * result + ((this.metadata == null) ? 0 : this.metadata.hashCode()); + result = prime * result + ((this.content == null) ? 0 : this.content.hashCode()); + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + SimpleVectorStoreContent other = (SimpleVectorStoreContent) obj; + if (this.id == null) { + if (other.id != null) { + return false; + } + } + else if (!this.id.equals(other.id)) { + return false; + } + if (this.metadata == null) { + if (other.metadata != null) { + return false; + } + } + else if (!this.metadata.equals(other.metadata)) { + return false; + } + if (this.content == null) { + if (other.content != null) { + return false; + } + } + else if (!this.content.equals(other.content)) { + return false; + } + return true; + } + + @Override + public String toString() { + return "Document{" + "id='" + this.id + '\'' + ", metadata=" + this.metadata + ", content='" + this.content + + '}'; + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/SimpleVectorStoreSimilarityTests.java b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/SimpleVectorStoreSimilarityTests.java new file mode 100644 index 00000000000..cf2d7476de6 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/SimpleVectorStoreSimilarityTests.java @@ -0,0 +1,46 @@ +/* + * Copyright 2023-2024 the original author or authors. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vectorstore; + +import java.util.HashMap; +import java.util.Map; + +import org.junit.Test; + +import org.springframework.ai.document.Document; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Ilayaperumal Gopinathan + */ +public class SimpleVectorStoreSimilarityTests { + + @Test + public void testSimilarity() { + Map metadata = new HashMap<>(); + metadata.put("foo", "bar"); + SimpleVectorStoreContent storeContent = new SimpleVectorStoreContent("1", "hello, how are you?", metadata); + SimpleVectorStore.Similarity similarity = new SimpleVectorStore.Similarity(storeContent, 0.6d); + Document document = similarity.getDocument(); + assertThat(document).isNotNull(); + assertThat(document.getId()).isEqualTo("1"); + assertThat(document.getContent()).isEqualTo("hello, how are you?"); + assertThat(document.getMetadata().get("foo")).isEqualTo("bar"); + } + +}