From e7269a8cbe50a561d498e40cc4081b0004b183aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edd=C3=BA=20Mel=C3=A9ndez?= Date: Sat, 11 May 2024 21:24:16 +0200 Subject: [PATCH] Use RestClient in ChromaApi Use RestClient internally but uses the configuration from RestTemplate. RestTemplate uses SimpleClientHttpRequestFactory meanwhile RestClient uses JdkClientHttpRequestFactory and produces `unsupported upgrade request` when making a request to chroma. --- .../ChromaVectorStoreAutoConfiguration.java | 5 +- .../springframework/ai/chroma/ChromaApi.java | 113 ++++++++++-------- .../ai/chroma/ChromaApiIT.java | 3 +- .../vectorstore/BasicAuthChromaWhereIT.java | 7 +- .../ai/vectorstore/ChromaVectorStoreIT.java | 4 +- .../TokenSecuredChromaWhereIT.java | 5 +- 6 files changed, 80 insertions(+), 57 deletions(-) diff --git a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfiguration.java b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfiguration.java index 74a2d3200f..f098daba5f 100644 --- a/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfiguration.java +++ b/spring-ai-spring-boot-autoconfigure/src/main/java/org/springframework/ai/autoconfigure/vectorstore/chroma/ChromaVectorStoreAutoConfiguration.java @@ -26,6 +26,7 @@ import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; import org.springframework.util.StringUtils; +import org.springframework.web.client.RestClient; import org.springframework.web.client.RestTemplate; /** @@ -56,7 +57,7 @@ public ChromaApi chromaApi(ChromaApiProperties apiProperties, RestTemplate restT String chromaUrl = String.format("%s:%s", connectionDetails.getHost(), connectionDetails.getPort()); - var chromaApi = new ChromaApi(chromaUrl, restTemplate, new ObjectMapper()); + var chromaApi = new ChromaApi(chromaUrl, RestClient.builder(restTemplate), new ObjectMapper()); if (StringUtils.hasText(apiProperties.getKeyToken())) { chromaApi.withKeyToken(apiProperties.getKeyToken()); @@ -75,7 +76,7 @@ public ChromaVectorStore vectorStore(EmbeddingClient embeddingClient, ChromaApi return new ChromaVectorStore(embeddingClient, chromaApi, storeProperties.getCollectionName()); } - private static class PropertiesChromaConnectionDetails implements ChromaConnectionDetails { + static class PropertiesChromaConnectionDetails implements ChromaConnectionDetails { private final ChromaApiProperties properties; diff --git a/vector-stores/spring-ai-chroma/src/main/java/org/springframework/ai/chroma/ChromaApi.java b/vector-stores/spring-ai-chroma/src/main/java/org/springframework/ai/chroma/ChromaApi.java index 04474cd532..c0df59d418 100644 --- a/vector-stores/spring-ai-chroma/src/main/java/org/springframework/ai/chroma/ChromaApi.java +++ b/vector-stores/spring-ai-chroma/src/main/java/org/springframework/ai/chroma/ChromaApi.java @@ -19,6 +19,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Consumer; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -27,41 +28,42 @@ import com.fasterxml.jackson.databind.ObjectMapper; import org.springframework.ai.chroma.ChromaApi.QueryRequest.Include; -import org.springframework.http.HttpEntity; +import org.springframework.core.ParameterizedTypeReference; import org.springframework.http.HttpHeaders; -import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; +import org.springframework.http.client.SimpleClientHttpRequestFactory; import org.springframework.http.client.support.BasicAuthenticationInterceptor; import org.springframework.util.CollectionUtils; import org.springframework.util.StringUtils; import org.springframework.web.client.HttpServerErrorException; -import org.springframework.web.client.RestTemplate; +import org.springframework.web.client.RestClient; /** * Single-class Chroma API implementation based on the (unofficial) Chroma REST API. * * @author Christian Tzolov + * @author EddĂș MelĂ©ndez */ public class ChromaApi { // Regular expression pattern that looks for a message inside the ValueError(...). private static Pattern VALUE_ERROR_PATTERN = Pattern.compile("ValueError\\('([^']*)'\\)"); - private final String baseUrl; - - private final RestTemplate restTemplate; + private final RestClient restClient; private final ObjectMapper objectMapper; private String keyToken; - public ChromaApi(String baseUrl, RestTemplate restTemplate) { - this(baseUrl, restTemplate, new ObjectMapper()); + public ChromaApi(String baseUrl, RestClient.Builder restClientBuilder) { + this(baseUrl, restClientBuilder, new ObjectMapper()); } - public ChromaApi(String baseUrl, RestTemplate restTemplate, ObjectMapper objectMapper) { - this.baseUrl = baseUrl; - this.restTemplate = restTemplate; + public ChromaApi(String baseUrl, RestClient.Builder restClientBuilder, ObjectMapper objectMapper) { + Consumer defaultHeaders = headers -> { + headers.setContentType(MediaType.APPLICATION_JSON); + }; + this.restClient = restClientBuilder.baseUrl(baseUrl).defaultHeaders(defaultHeaders).build(); this.objectMapper = objectMapper; } @@ -82,7 +84,7 @@ public ChromaApi withKeyToken(String keyToken) { * @param password Credentials password. */ public ChromaApi withBasicAuthCredentials(String username, String password) { - this.restTemplate.getInterceptors().add(new BasicAuthenticationInterceptor(username, password)); + this.restClient.mutate().requestInterceptor(new BasicAuthenticationInterceptor(username, password)); return this; } @@ -265,9 +267,12 @@ public List toEmbeddingResponseList(QueryResponse queryResponse) { public Collection createCollection(CreateCollectionRequest createCollectionRequest) { - return this.restTemplate - .exchange(this.baseUrl + "/api/v1/collections", HttpMethod.POST, - this.getHttpEntityFor(createCollectionRequest), Collection.class) + return this.restClient.post() + .uri("/api/v1/collections") + .headers(this::httpHeaders) + .body(createCollectionRequest) + .retrieve() + .toEntity(Collection.class) .getBody(); } @@ -278,16 +283,21 @@ public Collection createCollection(CreateCollectionRequest createCollectionReque */ public void deleteCollection(String collectionName) { - this.restTemplate.exchange(this.baseUrl + "/api/v1/collections/{collection_name}", HttpMethod.DELETE, - new HttpEntity<>(httpHeaders()), Void.class, collectionName); + this.restClient.delete() + .uri("/api/v1/collections/{collection_name}", collectionName) + .headers(this::httpHeaders) + .retrieve() + .toBodilessEntity(); } public Collection getCollection(String collectionName) { try { - return this.restTemplate - .exchange(this.baseUrl + "/api/v1/collections/{collection_name}", HttpMethod.GET, - new HttpEntity<>(httpHeaders()), Collection.class, collectionName) + return this.restClient.get() + .uri("/api/v1/collections/{collection_name}", collectionName) + .headers(this::httpHeaders) + .retrieve() + .toEntity(Collection.class) .getBody(); } catch (HttpServerErrorException e) { @@ -305,9 +315,11 @@ private static class CollectionList extends ArrayList { public List listCollections() { - return this.restTemplate - .exchange(this.baseUrl + "/api/v1/collections", HttpMethod.GET, new HttpEntity<>(httpHeaders()), - CollectionList.class) + return this.restClient.get() + .uri("/api/v1/collections") + .headers(this::httpHeaders) + .retrieve() + .toEntity(CollectionList.class) .getBody(); } @@ -317,41 +329,55 @@ public List listCollections() { public void upsertEmbeddings(String collectionId, AddEmbeddingsRequest embedding) { - this.restTemplate - .exchange(this.baseUrl + "/api/v1/collections/{collection_id}/upsert", HttpMethod.POST, - this.getHttpEntityFor(embedding), Boolean.class, collectionId) - .getBody(); + this.restClient.post() + .uri("/api/v1/collections/{collection_id}/upsert", collectionId) + .headers(this::httpHeaders) + .body(embedding) + .retrieve() + .toBodilessEntity(); } public List deleteEmbeddings(String collectionId, DeleteEmbeddingsRequest deleteRequest) { - return this.restTemplate - .exchange(this.baseUrl + "/api/v1/collections/{collection_id}/delete", HttpMethod.POST, - this.getHttpEntityFor(deleteRequest), List.class, collectionId) + return this.restClient.post() + .uri("/api/v1/collections/{collection_id}/delete", collectionId) + .headers(this::httpHeaders) + .body(deleteRequest) + .retrieve() + .toEntity(new ParameterizedTypeReference>() { + }) .getBody(); } public Long countEmbeddings(String collectionId) { - return this.restTemplate - .exchange(this.baseUrl + "/api/v1/collections/{collection_id}/count", HttpMethod.GET, - new HttpEntity<>(httpHeaders()), Long.class, collectionId) + return this.restClient.get() + .uri("/api/v1/collections/{collection_id}/count", collectionId) + .headers(this::httpHeaders) + .retrieve() + .toEntity(Long.class) .getBody(); } public QueryResponse queryCollection(String collectionId, QueryRequest queryRequest) { - return this.restTemplate - .exchange(this.baseUrl + "/api/v1/collections/{collection_id}/query", HttpMethod.POST, - this.getHttpEntityFor(queryRequest), QueryResponse.class, collectionId) + return this.restClient.post() + .uri("/api/v1/collections/{collection_id}/query", collectionId) + .headers(this::httpHeaders) + .body(queryRequest) + .retrieve() + .toEntity(QueryResponse.class) .getBody(); } public GetEmbeddingResponse getEmbeddings(String collectionId, GetEmbeddingsRequest getEmbeddingsRequest) { - return this.restTemplate - .exchange(this.baseUrl + "/api/v1/collections/{collection_id}/get", HttpMethod.POST, - this.getHttpEntityFor(getEmbeddingsRequest), GetEmbeddingResponse.class, collectionId) + return this.restClient.post() + .uri("/api/v1/collections/{collection_id}/get", collectionId) + .headers(this::httpHeaders) + .body(getEmbeddingsRequest) + .retrieve() + .toEntity(GetEmbeddingResponse.class) .getBody(); } @@ -365,17 +391,10 @@ public Map where(String text) { } } - private HttpEntity getHttpEntityFor(T body) { - return new HttpEntity<>(body, httpHeaders()); - } - - private HttpHeaders httpHeaders() { - HttpHeaders headers = new HttpHeaders(); - headers.setContentType(MediaType.APPLICATION_JSON); + private void httpHeaders(HttpHeaders headers) { if (StringUtils.hasText(this.keyToken)) { headers.setBearerAuth(this.keyToken); } - return headers; } private String getValueErrorMessage(String logString) { diff --git a/vector-stores/spring-ai-chroma/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java b/vector-stores/spring-ai-chroma/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java index a53814d569..a35aaaba24 100644 --- a/vector-stores/spring-ai-chroma/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java +++ b/vector-stores/spring-ai-chroma/src/test/java/org/springframework/ai/chroma/ChromaApiIT.java @@ -20,6 +20,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.springframework.web.client.RestClient; import org.testcontainers.chromadb.ChromaDBContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -185,7 +186,7 @@ public RestTemplate restTemplate() { @Bean public ChromaApi chromaApi(RestTemplate restTemplate) { - return new ChromaApi(chromaContainer.getEndpoint(), restTemplate); + return new ChromaApi(chromaContainer.getEndpoint(), RestClient.builder(restTemplate)); } } diff --git a/vector-stores/spring-ai-chroma/src/test/java/org/springframework/ai/vectorstore/BasicAuthChromaWhereIT.java b/vector-stores/spring-ai-chroma/src/test/java/org/springframework/ai/vectorstore/BasicAuthChromaWhereIT.java index a7b9ebf12c..d332998ced 100644 --- a/vector-stores/spring-ai-chroma/src/test/java/org/springframework/ai/vectorstore/BasicAuthChromaWhereIT.java +++ b/vector-stores/spring-ai-chroma/src/test/java/org/springframework/ai/vectorstore/BasicAuthChromaWhereIT.java @@ -19,6 +19,8 @@ import java.util.Map; import org.junit.jupiter.api.Test; +import org.springframework.web.client.RestClient; +import org.springframework.web.client.RestTemplate; import org.testcontainers.chromadb.ChromaDBContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -31,7 +33,6 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; -import org.springframework.web.client.RestTemplate; import org.testcontainers.utility.MountableFile; import static org.assertj.core.api.Assertions.assertThat; @@ -101,8 +102,8 @@ public RestTemplate restTemplate() { @Bean public ChromaApi chromaApi(RestTemplate restTemplate) { - return new ChromaApi(chromaContainer.getEndpoint(), restTemplate).withBasicAuthCredentials("admin", - "password"); + return new ChromaApi(chromaContainer.getEndpoint(), RestClient.builder(restTemplate)) + .withBasicAuthCredentials("admin", "password"); } @Bean diff --git a/vector-stores/spring-ai-chroma/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreIT.java b/vector-stores/spring-ai-chroma/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreIT.java index 9c5e8674ed..6d353686e0 100644 --- a/vector-stores/spring-ai-chroma/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreIT.java +++ b/vector-stores/spring-ai-chroma/src/test/java/org/springframework/ai/vectorstore/ChromaVectorStoreIT.java @@ -21,6 +21,7 @@ import java.util.UUID; import org.junit.jupiter.api.Test; +import org.springframework.web.client.RestClient; import org.testcontainers.chromadb.ChromaDBContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -29,7 +30,6 @@ import org.springframework.ai.document.Document; import org.springframework.ai.embedding.EmbeddingClient; import org.springframework.ai.openai.api.OpenAiApi; -import org.springframework.ai.vectorstore.ChromaVectorStore; import org.springframework.ai.openai.OpenAiEmbeddingClient; import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; @@ -209,7 +209,7 @@ public RestTemplate restTemplate() { @Bean public ChromaApi chromaApi(RestTemplate restTemplate) { - return new ChromaApi(chromaContainer.getEndpoint(), restTemplate); + return new ChromaApi(chromaContainer.getEndpoint(), RestClient.builder(restTemplate)); } @Bean diff --git a/vector-stores/spring-ai-chroma/src/test/java/org/springframework/ai/vectorstore/TokenSecuredChromaWhereIT.java b/vector-stores/spring-ai-chroma/src/test/java/org/springframework/ai/vectorstore/TokenSecuredChromaWhereIT.java index 4953732c1b..b6da839028 100644 --- a/vector-stores/spring-ai-chroma/src/test/java/org/springframework/ai/vectorstore/TokenSecuredChromaWhereIT.java +++ b/vector-stores/spring-ai-chroma/src/test/java/org/springframework/ai/vectorstore/TokenSecuredChromaWhereIT.java @@ -19,6 +19,8 @@ import java.util.Map; import org.junit.jupiter.api.Test; +import org.springframework.web.client.RestClient; +import org.springframework.web.client.RestTemplate; import org.testcontainers.chromadb.ChromaDBContainer; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -31,7 +33,6 @@ import org.springframework.boot.SpringBootConfiguration; import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.context.annotation.Bean; -import org.springframework.web.client.RestTemplate; import static org.assertj.core.api.Assertions.assertThat; @@ -132,7 +133,7 @@ public RestTemplate restTemplate() { @Bean public ChromaApi chromaApi(RestTemplate restTemplate) { - var chromaApi = new ChromaApi(chromaContainer.getEndpoint(), restTemplate); + var chromaApi = new ChromaApi(chromaContainer.getEndpoint(), RestClient.builder(restTemplate)); chromaApi.withKeyToken(CHROMA_SERVER_AUTH_CREDENTIALS); return chromaApi; }