Skip to content

Commit

Permalink
Chroma Vector Store: replace RestTemplate by RestClient
Browse files Browse the repository at this point in the history
 - replace RestTemplate by  RestClient but default to SimpleClientHttpRequestFactory as Chroma seems to have issues with HTTP2.
 - update the documentation.
 - update the ghcr.io/chroma-core/chroma version to 0.5.0. Fix the withBasicAuthCredentials.

Co-authored-by: Christian Tzolov <ctzolov@vmware.com>
  • Loading branch information
eddumelendez and tzolov committed Jun 16, 2024
1 parent 46e4784 commit 704dec2
Show file tree
Hide file tree
Showing 8 changed files with 151 additions and 123 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -205,19 +205,20 @@ TIP: Refer to the xref:getting-started.adoc#dependency-management[Dependency Man

=== Sample Code

Create a `RestTemplate` instance with proper ChromaDB authorization configurations and Use it to create a `ChromaApi` instance:
Create a `RestClient.Builder` instance with proper ChromaDB authorization configurations and Use it to create a `ChromaApi` instance:

[source,java]
----
@Bean
public RestTemplate restTemplate() {
return new RestTemplate();
public RestClient.Builder builder() {
return RestClient.builder().requestFactory(new SimpleClientHttpRequestFactory());
}
@Bean
public ChromaApi chromaApi(RestTemplate restTemplate) {
public ChromaApi chromaApi(RestClient.Builder restClientBuilder) {
String chromaUrl = "http://localhost:8000";
ChromaApi chromaApi = new ChromaApi(chromaUrl, restTemplate);
ChromaApi chromaApi = new ChromaApi(chromaUrl, restClientBuilder);
return chromaApi;
}
----
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
*/
package org.springframework.ai.autoconfigure.vectorstore.chroma;

import com.fasterxml.jackson.databind.ObjectMapper;

import org.springframework.ai.chroma.ChromaApi;
import org.springframework.ai.embedding.EmbeddingModel;
import org.springframework.ai.vectorstore.ChromaVectorStore;
Expand All @@ -25,15 +23,18 @@
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.util.StringUtils;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.client.RestClient;

import com.fasterxml.jackson.databind.ObjectMapper;

/**
* @author Christian Tzolov
* @author Eddú Meléndez
*/
@AutoConfiguration
@ConditionalOnClass({ EmbeddingModel.class, RestTemplate.class, ChromaVectorStore.class, ObjectMapper.class })
@ConditionalOnClass({ EmbeddingModel.class, RestClient.class, ChromaVectorStore.class, ObjectMapper.class })
@EnableConfigurationProperties({ ChromaApiProperties.class, ChromaVectorStoreProperties.class })
public class ChromaVectorStoreAutoConfiguration {

Expand All @@ -45,18 +46,18 @@ PropertiesChromaConnectionDetails chromaConnectionDetails(ChromaApiProperties pr

@Bean
@ConditionalOnMissingBean
public RestTemplate restTemplate() {
return new RestTemplate();
public RestClient.Builder builder() {
return RestClient.builder().requestFactory(new SimpleClientHttpRequestFactory());
}

@Bean
@ConditionalOnMissingBean
public ChromaApi chromaApi(ChromaApiProperties apiProperties, RestTemplate restTemplate,
public ChromaApi chromaApi(ChromaApiProperties apiProperties, RestClient.Builder restClientBuilder,
ChromaConnectionDetails connectionDetails) {

String chromaUrl = String.format("%s:%s", connectionDetails.getHost(), connectionDetails.getPort());

var chromaApi = new ChromaApi(chromaUrl, restTemplate, new ObjectMapper());
var chromaApi = new ChromaApi(chromaUrl, restClientBuilder, new ObjectMapper());

if (StringUtils.hasText(apiProperties.getKeyToken())) {
chromaApi.withKeyToken(apiProperties.getKeyToken());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
public class ChromaVectorStoreAutoConfigurationIT {

@Container
static ChromaDBContainer chroma = new ChromaDBContainer("ghcr.io/chroma-core/chroma:0.4.15");
static ChromaDBContainer chroma = new ChromaDBContainer("ghcr.io/chroma-core/chroma:0.5.0");

private final ApplicationContextRunner contextRunner = new ApplicationContextRunner()
.withConfiguration(AutoConfigurations.of(ChromaVectorStoreAutoConfiguration.class))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,49 +19,55 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;

import org.springframework.ai.chroma.ChromaApi.QueryRequest.Include;
import org.springframework.http.HttpEntity;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.http.client.support.BasicAuthenticationInterceptor;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;
import org.springframework.web.client.HttpServerErrorException;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.client.RestClient;

import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;

/**
* Single-class Chroma API implementation based on the (unofficial) Chroma REST API.
*
* @author Christian Tzolov
* @author Eddú Meléndez
*/
public class ChromaApi {

// Regular expression pattern that looks for a message inside the ValueError(...).
private static Pattern VALUE_ERROR_PATTERN = Pattern.compile("ValueError\\('([^']*)'\\)");

private final String baseUrl;

private final RestTemplate restTemplate;
private RestClient restClient;

private final ObjectMapper objectMapper;

private String keyToken;

public ChromaApi(String baseUrl, RestTemplate restTemplate) {
this(baseUrl, restTemplate, new ObjectMapper());
public ChromaApi(String baseUrl) {
this(baseUrl, RestClient.builder().requestFactory(new SimpleClientHttpRequestFactory()), new ObjectMapper());
}

public ChromaApi(String baseUrl, RestTemplate restTemplate, ObjectMapper objectMapper) {
this.baseUrl = baseUrl;
this.restTemplate = restTemplate;
public ChromaApi(String baseUrl, RestClient.Builder restClientBuilder) {
this(baseUrl, restClientBuilder, new ObjectMapper());
}

public ChromaApi(String baseUrl, RestClient.Builder restClientBuilder, ObjectMapper objectMapper) {
Consumer<HttpHeaders> defaultHeaders = headers -> {
headers.setContentType(MediaType.APPLICATION_JSON);
};
this.restClient = restClientBuilder.baseUrl(baseUrl).defaultHeaders(defaultHeaders).build();
this.objectMapper = objectMapper;
}

Expand All @@ -82,7 +88,9 @@ public ChromaApi withKeyToken(String keyToken) {
* @param password Credentials password.
*/
public ChromaApi withBasicAuthCredentials(String username, String password) {
this.restTemplate.getInterceptors().add(new BasicAuthenticationInterceptor(username, password));
this.restClient = this.restClient.mutate()
.requestInterceptor(new BasicAuthenticationInterceptor(username, password))
.build();
return this;
}

Expand Down Expand Up @@ -265,9 +273,23 @@ public List<Embedding> toEmbeddingResponseList(QueryResponse queryResponse) {

public Collection createCollection(CreateCollectionRequest createCollectionRequest) {

return this.restTemplate
.exchange(this.baseUrl + "/api/v1/collections", HttpMethod.POST,
this.getHttpEntityFor(createCollectionRequest), Collection.class)
return this.restClient.post()
.uri("/api/v1/collections")
.headers(this::httpHeaders)
.body(createCollectionRequest)
.retrieve()
.toEntity(Collection.class)
.getBody();
}

public Map<String, Object> createCollection2(CreateCollectionRequest createCollectionRequest) {

return this.restClient.post()
.uri("/api/v1/collections")
.headers(this::httpHeaders)
.body(createCollectionRequest)
.retrieve()
.toEntity(Map.class)
.getBody();
}

Expand All @@ -278,16 +300,21 @@ public Collection createCollection(CreateCollectionRequest createCollectionReque
*/
public void deleteCollection(String collectionName) {

this.restTemplate.exchange(this.baseUrl + "/api/v1/collections/{collection_name}", HttpMethod.DELETE,
new HttpEntity<>(httpHeaders()), Void.class, collectionName);
this.restClient.delete()
.uri("/api/v1/collections/{collection_name}", collectionName)
.headers(this::httpHeaders)
.retrieve()
.toBodilessEntity();
}

public Collection getCollection(String collectionName) {

try {
return this.restTemplate
.exchange(this.baseUrl + "/api/v1/collections/{collection_name}", HttpMethod.GET,
new HttpEntity<>(httpHeaders()), Collection.class, collectionName)
return this.restClient.get()
.uri("/api/v1/collections/{collection_name}", collectionName)
.headers(this::httpHeaders)
.retrieve()
.toEntity(Collection.class)
.getBody();
}
catch (HttpServerErrorException e) {
Expand All @@ -305,9 +332,11 @@ private static class CollectionList extends ArrayList<Collection> {

public List<Collection> listCollections() {

return this.restTemplate
.exchange(this.baseUrl + "/api/v1/collections", HttpMethod.GET, new HttpEntity<>(httpHeaders()),
CollectionList.class)
return this.restClient.get()
.uri("/api/v1/collections")
.headers(this::httpHeaders)
.retrieve()
.toEntity(CollectionList.class)
.getBody();
}

Expand All @@ -317,41 +346,55 @@ public List<Collection> listCollections() {

public void upsertEmbeddings(String collectionId, AddEmbeddingsRequest embedding) {

this.restTemplate
.exchange(this.baseUrl + "/api/v1/collections/{collection_id}/upsert", HttpMethod.POST,
this.getHttpEntityFor(embedding), Boolean.class, collectionId)
.getBody();
this.restClient.post()
.uri("/api/v1/collections/{collection_id}/upsert", collectionId)
.headers(this::httpHeaders)
.body(embedding)
.retrieve()
.toBodilessEntity();
}

public List<String> deleteEmbeddings(String collectionId, DeleteEmbeddingsRequest deleteRequest) {

return this.restTemplate
.exchange(this.baseUrl + "/api/v1/collections/{collection_id}/delete", HttpMethod.POST,
this.getHttpEntityFor(deleteRequest), List.class, collectionId)
return this.restClient.post()
.uri("/api/v1/collections/{collection_id}/delete", collectionId)
.headers(this::httpHeaders)
.body(deleteRequest)
.retrieve()
.toEntity(new ParameterizedTypeReference<List<String>>() {
})
.getBody();
}

public Long countEmbeddings(String collectionId) {

return this.restTemplate
.exchange(this.baseUrl + "/api/v1/collections/{collection_id}/count", HttpMethod.GET,
new HttpEntity<>(httpHeaders()), Long.class, collectionId)
return this.restClient.get()
.uri("/api/v1/collections/{collection_id}/count", collectionId)
.headers(this::httpHeaders)
.retrieve()
.toEntity(Long.class)
.getBody();
}

public QueryResponse queryCollection(String collectionId, QueryRequest queryRequest) {

return this.restTemplate
.exchange(this.baseUrl + "/api/v1/collections/{collection_id}/query", HttpMethod.POST,
this.getHttpEntityFor(queryRequest), QueryResponse.class, collectionId)
return this.restClient.post()
.uri("/api/v1/collections/{collection_id}/query", collectionId)
.headers(this::httpHeaders)
.body(queryRequest)
.retrieve()
.toEntity(QueryResponse.class)
.getBody();
}

public GetEmbeddingResponse getEmbeddings(String collectionId, GetEmbeddingsRequest getEmbeddingsRequest) {

return this.restTemplate
.exchange(this.baseUrl + "/api/v1/collections/{collection_id}/get", HttpMethod.POST,
this.getHttpEntityFor(getEmbeddingsRequest), GetEmbeddingResponse.class, collectionId)
return this.restClient.post()
.uri("/api/v1/collections/{collection_id}/get", collectionId)
.headers(this::httpHeaders)
.body(getEmbeddingsRequest)
.retrieve()
.toEntity(GetEmbeddingResponse.class)
.getBody();
}

Expand All @@ -365,17 +408,10 @@ public Map<String, Object> where(String text) {
}
}

private <T> HttpEntity<T> getHttpEntityFor(T body) {
return new HttpEntity<>(body, httpHeaders());
}

private HttpHeaders httpHeaders() {
HttpHeaders headers = new HttpHeaders();
headers.setContentType(MediaType.APPLICATION_JSON);
private void httpHeaders(HttpHeaders headers) {
if (StringUtils.hasText(this.keyToken)) {
headers.setBearerAuth(this.keyToken);
}
return headers;
}

private String getValueErrorMessage(String logString) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,24 @@
*/
package org.springframework.ai.chroma;

import static org.assertj.core.api.Assertions.assertThat;

import java.util.List;
import java.util.Map;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.testcontainers.chromadb.ChromaDBContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.Bean;
import org.springframework.ai.chroma.ChromaApi.AddEmbeddingsRequest;
import org.springframework.ai.chroma.ChromaApi.Collection;
import org.springframework.ai.chroma.ChromaApi.GetEmbeddingsRequest;
import org.springframework.ai.chroma.ChromaApi.QueryRequest;
import org.springframework.web.client.RestTemplate;

import static org.assertj.core.api.Assertions.assertThat;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.SpringBootConfiguration;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.context.annotation.Bean;
import org.testcontainers.chromadb.ChromaDBContainer;
import org.testcontainers.junit.jupiter.Container;
import org.testcontainers.junit.jupiter.Testcontainers;

/**
* @author Christian Tzolov
Expand All @@ -45,7 +43,7 @@
public class ChromaApiIT {

@Container
static ChromaDBContainer chromaContainer = new ChromaDBContainer("ghcr.io/chroma-core/chroma:0.4.22");
static ChromaDBContainer chromaContainer = new ChromaDBContainer("ghcr.io/chroma-core/chroma:0.4.12");

@Autowired
ChromaApi chroma;
Expand Down Expand Up @@ -179,13 +177,8 @@ public void testQueryWhere() {
public static class Config {

@Bean
public RestTemplate restTemplate() {
return new RestTemplate();
}

@Bean
public ChromaApi chromaApi(RestTemplate restTemplate) {
return new ChromaApi(chromaContainer.getEndpoint(), restTemplate);
public ChromaApi chromaApi() {
return new ChromaApi(chromaContainer.getEndpoint());
}

}
Expand Down
Loading

0 comments on commit 704dec2

Please sign in to comment.