Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enhancing Elasticsearch vector store implementation #592

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,7 @@
<pgvector.version>0.1.4</pgvector.version>
<sap.hanadb.version>2.20.11</sap.hanadb.version>
<postgresql.version>42.7.2</postgresql.version>
<elasticsearch-java.version>8.13.3</elasticsearch-java.version>
<milvus.version>2.3.4</milvus.version>
<pinecone.version>0.8.0</pinecone.version>
<fastjson.version>2.0.46</fastjson.version>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,11 +134,18 @@ Properties starting with the `spring.ai.vectorstore.elasticsearch.*` prefix are

|`spring.ai.vectorstore.elasticsearch.index-name` | The name of the index to store the vectors. | spring-ai-document-index
|`spring.ai.vectorstore.elasticsearch.dimensions` | The number of dimensions in the vector. | 1536
|`spring.ai.vectorstore.elasticsearch.dense-vector-indexing` | Whether to use dense vector indexing. | true
|`spring.ai.vectorstore.elasticsearch.similarity` | The similarity function to use. | `cosine`
|`spring.ai.vectorstore.elasticsearch.initialize-schema`| whether to initialize the required schema | `false`
|===

The following similarity functions are available:

* cosine
* l2_norm
* dot_product

More details about each in the https://www.elastic.co/guide/en/elasticsearch/reference/master/dense-vector.html#dense-vector-params[Elasticsearch Documentation] on dense vectors.

== Metadata Filtering

You can leverage the generic, portable xref:api/vectordbs.adoc#metadata-filters[metadata filters] with Elasticsearch as well.
Expand Down Expand Up @@ -214,10 +221,11 @@ Read the link:https://www.elastic.co/guide/en/elasticsearch/client/java-api-clie
----
@Bean
public RestClient restClient() {
RestClientBuilder builder = RestClient.builder(new HttpHost("<host>", 9200, "http"));
Header[] defaultHeaders = new Header[] { new BasicHeader("Authorization", "Basic <encoded username and password>") };
builder.setDefaultHeaders(defaultHeaders);
return builder.build();
RestClient.builder(new HttpHost("<host>", 9200, "http"))
.setDefaultHeaders(new Header[]{
new BasicHeader("Authorization", "Basic <encoded username and password>")
})
.build();
}
----

Expand Down
7 changes: 7 additions & 0 deletions spring-ai-spring-boot-autoconfigure/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@
<optional>true</optional>
</dependency>

<!-- Elasticsearch Vector Store-->
<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-elasticsearch-store</artifactId>
Expand All @@ -281,6 +282,12 @@
<optional>true</optional>
</dependency>

<dependency>
<groupId>co.elastic.clients</groupId>
<artifactId>elasticsearch-java</artifactId>
<version>${elasticsearch-java.version}</version>
</dependency>

<!-- test dependencies -->

<dependency>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,7 @@ ElasticsearchVectorStore vectorStore(ElasticsearchVectorStoreProperties properti
if (properties.getDimensions() != null) {
elasticsearchVectorStoreOptions.setDimensions(properties.getDimensions());
}
if (properties.isDenseVectorIndexing() != null) {
elasticsearchVectorStoreOptions.setDenseVectorIndexing(properties.isDenseVectorIndexing());
}
if (StringUtils.hasText(properties.getSimilarity())) {
if (properties.getSimilarity() != null) {
elasticsearchVectorStoreOptions.setSimilarity(properties.getSimilarity());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package org.springframework.ai.autoconfigure.vectorstore.elasticsearch;

import org.springframework.ai.autoconfigure.CommonVectorStoreProperties;
import org.springframework.ai.vectorstore.SimilarityFunction;
import org.springframework.boot.context.properties.ConfigurationProperties;

/**
Expand All @@ -37,15 +38,10 @@ public class ElasticsearchVectorStoreProperties extends CommonVectorStorePropert
*/
private Integer dimensions;

/**
* Whether to use dense vector indexing.
*/
private Boolean denseVectorIndexing;

/**
* The similarity function to use.
*/
private String similarity;
private SimilarityFunction similarity;

public String getIndexName() {
return this.indexName;
Expand All @@ -63,19 +59,11 @@ public void setDimensions(Integer dimensions) {
this.dimensions = dimensions;
}

public Boolean isDenseVectorIndexing() {
return denseVectorIndexing;
}

public void setDenseVectorIndexing(Boolean denseVectorIndexing) {
this.denseVectorIndexing = denseVectorIndexing;
}

public String getSimilarity() {
public SimilarityFunction getSimilarity() {
return similarity;
}

public void setSimilarity(String similarity) {
public void setSimilarity(SimilarityFunction similarity) {
this.similarity = similarity;
}

Expand Down
2 changes: 1 addition & 1 deletion vector-stores/spring-ai-elasticsearch-store/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
<dependency>
<groupId>co.elastic.clients</groupId>
<artifactId>elasticsearch-java</artifactId>
<version>${elasticsearch-java.version}</version>
</dependency>

<!-- TESTING -->
Expand All @@ -45,7 +46,6 @@
<scope>test</scope>
</dependency>


<dependency>
<groupId>org.springframework.ai</groupId>
<artifactId>spring-ai-test</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,17 +16,13 @@
package org.springframework.ai.vectorstore;

import co.elastic.clients.elasticsearch.ElasticsearchClient;
import co.elastic.clients.elasticsearch._types.mapping.DenseVectorProperty;
import co.elastic.clients.elasticsearch._types.mapping.Property;
import co.elastic.clients.elasticsearch._types.query_dsl.Query;
import co.elastic.clients.elasticsearch.core.BulkRequest;
import co.elastic.clients.elasticsearch.core.BulkResponse;
import co.elastic.clients.elasticsearch.core.SearchResponse;
import co.elastic.clients.elasticsearch.core.bulk.BulkResponseItem;
import co.elastic.clients.elasticsearch.core.search.Hit;
import co.elastic.clients.elasticsearch.indices.CreateIndexResponse;
import co.elastic.clients.json.JsonData;
import co.elastic.clients.json.jackson.JacksonJsonpMapper;
import co.elastic.clients.transport.endpoints.BooleanResponse;
import co.elastic.clients.transport.rest_client.RestClientTransport;
import com.fasterxml.jackson.databind.DeserializationFeature;
import com.fasterxml.jackson.databind.ObjectMapper;
Expand All @@ -46,16 +42,17 @@
import java.util.Optional;
import java.util.stream.Collectors;

import static java.lang.Math.sqrt;
import static org.springframework.ai.vectorstore.SimilarityFunction.l2_norm;

/**
* @author Jemin Huh
* @author Wei Jiang
* @author Laura Trotta
* @since 1.0.0
*/
public class ElasticsearchVectorStore implements VectorStore, InitializingBean {

// divided by 2 to get score in the range [0, 1]
public static final String COSINE_SIMILARITY_FUNCTION = "(cosineSimilarity(params.query_vector, 'embedding') + 1.0) / 2";

private static final Logger logger = LoggerFactory.getLogger(ElasticsearchVectorStore.class);

private final EmbeddingModel embeddingModel;
Expand All @@ -66,8 +63,6 @@ public class ElasticsearchVectorStore implements VectorStore, InitializingBean {

private final FilterExpressionConverter filterExpressionConverter;

private String similarityFunction;

private final boolean initializeSchema;

public ElasticsearchVectorStore(RestClient restClient, EmbeddingModel embeddingModel, boolean initializeSchema) {
Expand All @@ -84,30 +79,22 @@ public ElasticsearchVectorStore(ElasticsearchVectorStoreOptions options, RestCli
this.embeddingModel = embeddingModel;
this.options = options;
this.filterExpressionConverter = new ElasticsearchAiSearchFilterExpressionConverter();
// the potential functions for vector fields at
// https://www.elastic.co/guide/en/elasticsearch/reference/current/query-dsl-script-score-query.html#vector-functions
this.similarityFunction = COSINE_SIMILARITY_FUNCTION;
}

public ElasticsearchVectorStore withSimilarityFunction(String similarityFunction) {
this.similarityFunction = similarityFunction;
return this;
}

@Override
public void add(List<Document> documents) {
BulkRequest.Builder builkRequestBuilder = new BulkRequest.Builder();
BulkRequest.Builder bulkRequestBuilder = new BulkRequest.Builder();

for (Document document : documents) {
if (Objects.isNull(document.getEmbedding()) || document.getEmbedding().isEmpty()) {
logger.debug("Calling EmbeddingModel for document id = " + document.getId());
document.setEmbedding(this.embeddingModel.embed(document));
}
builkRequestBuilder.operations(op -> op
bulkRequestBuilder.operations(op -> op
.index(idx -> idx.index(this.options.getIndexName()).id(document.getId()).document(document)));
}

BulkResponse bulkRequest = bulkRequest(builkRequestBuilder.build());
BulkResponse bulkRequest = bulkRequest(bulkRequestBuilder.build());

if (bulkRequest.errors()) {
List<BulkResponseItem> bulkResponseItems = bulkRequest.items();
Expand All @@ -121,10 +108,10 @@ public void add(List<Document> documents) {

@Override
public Optional<Boolean> delete(List<String> idList) {
BulkRequest.Builder builkRequestBuilder = new BulkRequest.Builder();
BulkRequest.Builder bulkRequestBuilder = new BulkRequest.Builder();
for (String id : idList)
builkRequestBuilder.operations(op -> op.delete(idx -> idx.index(this.options.getIndexName()).id(id)));
return Optional.of(bulkRequest(builkRequestBuilder.build()).errors());
bulkRequestBuilder.operations(op -> op.delete(idx -> idx.index(this.options.getIndexName()).id(id)));
return Optional.of(bulkRequest(bulkRequestBuilder.build()).errors());
}

private BulkResponse bulkRequest(BulkRequest bulkRequest) {
Expand All @@ -139,61 +126,67 @@ private BulkResponse bulkRequest(BulkRequest bulkRequest) {
@Override
public List<Document> similaritySearch(SearchRequest searchRequest) {
Assert.notNull(searchRequest, "The search request must not be null.");
return similaritySearch(this.embeddingModel.embed(searchRequest.getQuery()), searchRequest.getTopK(),
Double.valueOf(searchRequest.getSimilarityThreshold()).floatValue(),
searchRequest.getFilterExpression());
}

public List<Document> similaritySearch(List<Double> embedding, int topK, double similarityThreshold,
Filter.Expression filterExpression) {
return similaritySearch(
new co.elastic.clients.elasticsearch.core.SearchRequest.Builder().index(options.getIndexName())
.query(getElasticsearchSimilarityQuery(embedding, filterExpression))
.size(topK)
.minScore(similarityThreshold)
.build());
}

private Query getElasticsearchSimilarityQuery(List<Double> embedding, Filter.Expression filterExpression) {
return Query.of(queryBuilder -> queryBuilder.scriptScore(scriptScoreQueryBuilder -> scriptScoreQueryBuilder
.query(queryBuilder2 -> queryBuilder2.queryString(queryStringQuerybuilder -> queryStringQuerybuilder
.query(getElasticsearchQueryString(filterExpression))))
.script(scriptBuilder -> scriptBuilder
.inline(inlineScriptBuilder -> inlineScriptBuilder.source(this.similarityFunction)
.params("query_vector", JsonData.of(embedding))))));
}

private String getElasticsearchQueryString(Filter.Expression filterExpression) {
return Objects.isNull(filterExpression) ? "*"
: this.filterExpressionConverter.convertExpression(filterExpression);

}

private List<Document> similaritySearch(co.elastic.clients.elasticsearch.core.SearchRequest searchRequest) {
try {
return this.elasticsearchClient.search(searchRequest, Document.class)
.hits()
.hits()
float threshold = (float) searchRequest.getSimilarityThreshold();
// reverting l2_norm distance to its original value
if (options.getSimilarity().equals(l2_norm)) {
threshold = 1 - threshold;
}
final float finalThreshold = threshold;
List<Float> vectors = this.embeddingModel.embed(searchRequest.getQuery())
.stream()
.map(this::toDocument)
.collect(Collectors.toList());
.map(Double::floatValue)
.toList();

SearchResponse<Document> res = elasticsearchClient.search(
sr -> sr.index(options.getIndexName())
.knn(knn -> knn.queryVector(vectors)
.similarity(finalThreshold)
.k((long) searchRequest.getTopK())
.field("embedding")
.numCandidates((long) (1.5 * searchRequest.getTopK()))
.filter(fl -> fl.queryString(
qs -> qs.query(getElasticsearchQueryString(searchRequest.getFilterExpression()))))),
Document.class);

return res.hits().hits().stream().map(this::toDocument).collect(Collectors.toList());
}
catch (IOException e) {
throw new RuntimeException(e);
}
}

private String getElasticsearchQueryString(Filter.Expression filterExpression) {
return Objects.isNull(filterExpression) ? "*"
: this.filterExpressionConverter.convertExpression(filterExpression);

}

private Document toDocument(Hit<Document> hit) {
Document document = hit.source();
document.getMetadata().put("distance", 1 - hit.score().floatValue());
document.getMetadata().put("distance", calculateDistance(hit.score().floatValue()));
return document;
}

private boolean indexExists() {
// more info on score/distance calculation
// https://www.elastic.co/guide/en/elasticsearch/reference/current/knn-search.html#knn-similarity-search
private float calculateDistance(Float score) {
switch (options.getSimilarity()) {
case l2_norm:
// the returned value of l2_norm is the opposite of the other functions
// (closest to zero means more accurate), so to make it consistent
// with the other functions the reverse is returned applying a "1-"
// to the standard transformation
return (float) (1 - (sqrt((1 / score) - 1)));
// cosine and dot_product
default:
return (2 * score) - 1;
}
}

public boolean indexExists() {
try {
BooleanResponse response = this.elasticsearchClient.indices()
.exists(existRequestBuilder -> existRequestBuilder.index(options.getIndexName()));
return response.value();
return this.elasticsearchClient.indices().exists(ex -> ex.index(options.getIndexName())).value();
}
catch (IOException e) {
throw new RuntimeException(e);
Expand All @@ -203,18 +196,9 @@ private boolean indexExists() {
private CreateIndexResponse createIndexMapping() {
try {
return this.elasticsearchClient.indices()
.create(createIndexBuilder -> createIndexBuilder.index(options.getIndexName())
.mappings(typeMappingBuilder -> {
typeMappingBuilder.properties("embedding",
new Property.Builder()
.denseVector(new DenseVectorProperty.Builder().dims(options.getDimensions())
.similarity(options.getSimilarity())
.index(options.isDenseVectorIndexing())
.build())
.build());

return typeMappingBuilder;
}));
.create(cr -> cr.index(options.getIndexName())
.mappings(map -> map.properties("embedding", p -> p.denseVector(
dv -> dv.similarity(options.getSimilarity().toString()).dims(options.getDimensions())))));
}
catch (IOException e) {
throw new RuntimeException(e);
Expand All @@ -233,4 +217,4 @@ public void afterPropertiesSet() {
}
}

}
}