From 51006728a238d091c7b518ec4975bbc2aeac1887 Mon Sep 17 00:00:00 2001 From: Thomas Vitale Date: Sun, 10 Nov 2024 14:11:13 +0100 Subject: [PATCH 1/2] Modular RAG - Query Analysis MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Query Analysis * Introduce Query Analysis Module * Define QueryTransformer API and TranslationQueryTransformer implementation * Define QueryExpander API and MultiQueryExpander implementation * Support QueryTransformer in RetrievalAugmentationAdvisor (support for QueryExpander will be in the next PR together with the needed DocumentFuser API). Improvements * Refine Retrieval and Augmentation Modules for increased robustness * Expand test coverage for both modules * Define clone() method for ChatClient.Builder Tests * Introduce “spring-ai-integration-tests” for full-fledged integration tests * Add integration tests for RAG modules * Add integration tests for RAG advisor Relates to #gh-1603 Signed-off-by: Thomas Vitale --- pom.xml | 2 + .../ai/chat/client/ChatClient.java | 2 + .../chat/client/DefaultChatClientBuilder.java | 4 + .../advisor/RetrievalAugmentationAdvisor.java | 71 +++++-- .../ai/rag/analysis/package-info.java | 29 +++ .../query/expansion/MultiQueryExpander.java | 176 ++++++++++++++++++ .../query/expansion/QueryExpander.java | 53 ++++++ .../query/expansion}/package-info.java | 8 +- .../transformation/QueryTransformer.java | 48 +++++ .../TranslationQueryTransformer.java | 137 ++++++++++++++ .../query/transformation/package-info.java | 25 +++ .../ContextualQueryAugmentor.java | 17 +- .../ai/rag/augmentation/QueryAugmentor.java | 3 +- .../ai/rag/augmentation/package-info.java | 9 +- .../springframework/ai/rag/package-info.java | 12 +- .../ai/rag/retrieval/package-info.java | 6 +- .../{source => search}/DocumentRetriever.java | 25 ++- .../VectorStoreDocumentRetriever.java | 10 +- .../ai/rag/retrieval/search/package-info.java | 25 +++ .../client/DefaultChatClientBuilderTests.java | 17 ++ .../RetrievalAugmentationAdvisorTests.java | 30 ++- .../expansion/MultiQueryExpanderTests.java | 70 +++++++ .../TranslationQueryTransformerTests.java | 73 ++++++++ .../ContextualQueryAugmentorTests.java | 18 ++ .../VectorStoreDocumentRetrieverTests.java | 12 +- spring-ai-integration-tests/pom.xml | 96 ++++++++++ .../ai/integration/tests/TestApplication.java | 30 +++ .../tests/TestcontainersConfiguration.java | 37 ++++ .../RetrievalAugmentationAdvisorIT.java | 140 ++++++++++++++ .../query/expansion/MultiQueryExpanderIT.java | 90 +++++++++ .../TranslationQueryTransformerIT.java | 58 ++++++ .../ContextualQueryAugmentorIT.java | 89 +++++++++ .../VectorStoreDocumentRetrieverIT.java | 111 +++++++++++ .../src/test/resources/application.yml | 16 ++ .../resources/documents/knowledge-base.md | 41 ++++ 35 files changed, 1519 insertions(+), 71 deletions(-) create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/rag/analysis/package-info.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/rag/analysis/query/expansion/MultiQueryExpander.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/rag/analysis/query/expansion/QueryExpander.java rename spring-ai-core/src/main/java/org/springframework/ai/rag/{retrieval/source => analysis/query/expansion}/package-info.java (79%) create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/rag/analysis/query/transformation/QueryTransformer.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/rag/analysis/query/transformation/TranslationQueryTransformer.java create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/rag/analysis/query/transformation/package-info.java rename spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/{source => search}/DocumentRetriever.java (63%) rename spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/{source => search}/VectorStoreDocumentRetriever.java (92%) create mode 100644 spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/search/package-info.java create mode 100644 spring-ai-core/src/test/java/org/springframework/ai/rag/analysis/query/expansion/MultiQueryExpanderTests.java create mode 100644 spring-ai-core/src/test/java/org/springframework/ai/rag/analysis/query/transformation/TranslationQueryTransformerTests.java rename spring-ai-core/src/test/java/org/springframework/ai/rag/retrieval/{source => search}/VectorStoreDocumentRetrieverTests.java (96%) create mode 100644 spring-ai-integration-tests/pom.xml create mode 100644 spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/TestApplication.java create mode 100644 spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/TestcontainersConfiguration.java create mode 100644 spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/RetrievalAugmentationAdvisorIT.java create mode 100644 spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/rag/analysis/query/expansion/MultiQueryExpanderIT.java create mode 100644 spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/rag/analysis/query/transformation/TranslationQueryTransformerIT.java create mode 100644 spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/rag/augmentation/ContextualQueryAugmentorIT.java create mode 100644 spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/rag/retrieval/search/VectorStoreDocumentRetrieverIT.java create mode 100644 spring-ai-integration-tests/src/test/resources/application.yml create mode 100644 spring-ai-integration-tests/src/test/resources/documents/knowledge-base.md diff --git a/pom.xml b/pom.xml index 1c19b99f362..1fefcf04130 100644 --- a/pom.xml +++ b/pom.xml @@ -123,6 +123,8 @@ spring-ai-spring-boot-starters/spring-ai-starter-watsonx-ai spring-ai-spring-boot-starters/spring-ai-starter-zhipuai spring-ai-spring-boot-starters/spring-ai-starter-moonshot + + spring-ai-integration-tests diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java index d53e998967a..b6a5642a7c2 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/ChatClient.java @@ -289,6 +289,8 @@ Builder defaultFunction(String name, String description, Builder defaultToolContext(Map toolContext); + Builder clone(); + ChatClient build(); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java index 4ae3833d868..7f4fbdbff17 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/DefaultChatClientBuilder.java @@ -71,6 +71,10 @@ public ChatClient build() { return new DefaultChatClient(this.defaultRequest); } + public Builder clone() { + return this.defaultRequest.mutate(); + } + public Builder defaultAdvisors(Advisor... advisors) { this.defaultRequest.advisors(advisors); return this; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java index 70c49ff6a70..4474e87a40a 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java @@ -16,11 +16,14 @@ package org.springframework.ai.chat.client.advisor; +import java.util.ArrayList; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.function.Predicate; +import org.springframework.ai.rag.analysis.query.transformation.QueryTransformer; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.scheduler.Schedulers; @@ -37,13 +40,13 @@ import org.springframework.ai.rag.Query; import org.springframework.ai.rag.augmentation.ContextualQueryAugmentor; import org.springframework.ai.rag.augmentation.QueryAugmentor; -import org.springframework.ai.rag.retrieval.source.DocumentRetriever; +import org.springframework.ai.rag.retrieval.search.DocumentRetriever; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.StringUtils; /** - * This advisor implements common Retrieval Augmented Generation (RAG) flows using the + * Advisor that implements common Retrieval Augmented Generation (RAG) flows using the * building blocks defined in the {@link org.springframework.ai.rag} package and following * the Modular RAG Architecture. *

@@ -55,10 +58,12 @@ * @see arXiv:2407.21059 * @see arXiv:2312.10997 */ -public class RetrievalAugmentationAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { +public final class RetrievalAugmentationAdvisor implements CallAroundAdvisor, StreamAroundAdvisor { public static final String DOCUMENT_CONTEXT = "rag_document_context"; + private final List queryTransformers; + private final DocumentRetriever documentRetriever; private final QueryAugmentor queryAugmentor; @@ -67,9 +72,12 @@ public class RetrievalAugmentationAdvisor implements CallAroundAdvisor, StreamAr private final int order; - public RetrievalAugmentationAdvisor(DocumentRetriever documentRetriever, @Nullable QueryAugmentor queryAugmentor, - @Nullable Boolean protectFromBlocking, @Nullable Integer order) { + public RetrievalAugmentationAdvisor(List queryTransformers, DocumentRetriever documentRetriever, + @Nullable QueryAugmentor queryAugmentor, @Nullable Boolean protectFromBlocking, @Nullable Integer order) { + Assert.notNull(queryTransformers, "queryTransformers cannot be null"); + Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements"); Assert.notNull(documentRetriever, "documentRetriever cannot be null"); + this.queryTransformers = queryTransformers; this.documentRetriever = documentRetriever; this.queryAugmentor = queryAugmentor != null ? queryAugmentor : ContextualQueryAugmentor.builder().build(); this.protectFromBlocking = protectFromBlocking != null ? protectFromBlocking : false; @@ -119,30 +127,45 @@ private AdvisedRequest before(AdvisedRequest request) { Map context = new HashMap<>(request.adviseContext()); // 0. Create a query from the user text and parameters. - Query query = new Query(new PromptTemplate(request.userText(), request.userParams()).render()); + Query originalQuery = new Query(new PromptTemplate(request.userText(), request.userParams()).render()); + + // 1. Transform original user query based on a chain of query transformers. + Query transformedQuery = originalQuery; + for (var queryTransformer : queryTransformers) { + transformedQuery = queryTransformer.apply(transformedQuery); + } - // 1. Retrieve similar documents for the original query. - List documents = this.documentRetriever.retrieve(query); + // 2. Retrieve similar documents for the original query. + List documents = this.documentRetriever.retrieve(transformedQuery); context.put(DOCUMENT_CONTEXT, documents); - // 2. Augment user query with the contextual data. - Query augmentedQuery = this.queryAugmentor.augment(query, documents); + // 3. Augment user query with the document contextual data. + Query augmentedQuery = this.queryAugmentor.augment(transformedQuery, documents); return AdvisedRequest.from(request).withUserText(augmentedQuery.text()).withAdviseContext(context).build(); } private AdvisedResponse after(AdvisedResponse advisedResponse) { - ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(advisedResponse.response()); + ChatResponse.Builder chatResponseBuilder; + if (advisedResponse.response() == null) { + chatResponseBuilder = ChatResponse.builder(); + } + else { + chatResponseBuilder = ChatResponse.builder().from(advisedResponse.response()); + } chatResponseBuilder.withMetadata(DOCUMENT_CONTEXT, advisedResponse.adviseContext().get(DOCUMENT_CONTEXT)); return new AdvisedResponse(chatResponseBuilder.build(), advisedResponse.adviseContext()); } private Predicate onFinishReason() { - return advisedResponse -> advisedResponse.response() - .getResults() - .stream() - .anyMatch(result -> result != null && result.getMetadata() != null - && StringUtils.hasText(result.getMetadata().getFinishReason())); + return advisedResponse -> { + ChatResponse chatResponse = advisedResponse.response(); + return chatResponse != null && chatResponse.getResults() != null + && chatResponse.getResults() + .stream() + .anyMatch(result -> result != null && result.getMetadata() != null + && StringUtils.hasText(result.getMetadata().getFinishReason())); + }; } @Override @@ -157,6 +180,8 @@ public int getOrder() { public static final class Builder { + private final List queryTransformers = new ArrayList<>(); + private DocumentRetriever documentRetriever; private QueryAugmentor queryAugmentor; @@ -168,6 +193,18 @@ public static final class Builder { private Builder() { } + public Builder queryTransformers(List queryTransformers) { + Assert.notNull(queryTransformers, "queryTransformers cannot be null"); + this.queryTransformers.addAll(queryTransformers); + return this; + } + + public Builder queryTransformers(QueryTransformer... queryTransformers) { + Assert.notNull(queryTransformers, "queryTransformers cannot be null"); + this.queryTransformers.addAll(Arrays.asList(queryTransformers)); + return this; + } + public Builder documentRetriever(DocumentRetriever documentRetriever) { this.documentRetriever = documentRetriever; return this; @@ -189,7 +226,7 @@ public Builder order(Integer order) { } public RetrievalAugmentationAdvisor build() { - return new RetrievalAugmentationAdvisor(this.documentRetriever, this.queryAugmentor, + return new RetrievalAugmentationAdvisor(this.queryTransformers, this.documentRetriever, this.queryAugmentor, this.protectFromBlocking, this.order); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/analysis/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/analysis/package-info.java new file mode 100644 index 00000000000..ee9deac32ac --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/analysis/package-info.java @@ -0,0 +1,29 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * RAG Module: Query Analysis. + *

+ * This package encompasses all components involved in the pre-retrieval phase of a + * retrieval augmented generation flow. Queries are transformed, expanded, or constructed + * so to enhance the effectiveness and accuracy of the subsequent retrieval phase. + */ +@NonNullApi +@NonNullFields +package org.springframework.ai.rag.analysis; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/analysis/query/expansion/MultiQueryExpander.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/analysis/query/expansion/MultiQueryExpander.java new file mode 100644 index 00000000000..2ae3879393a --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/analysis/query/expansion/MultiQueryExpander.java @@ -0,0 +1,176 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.rag.analysis.query.expansion; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.rag.Query; +import org.springframework.ai.util.PromptAssert; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.CollectionUtils; +import org.springframework.util.StringUtils; + +import java.util.Arrays; +import java.util.List; +import java.util.stream.Collectors; + +/** + * Expander that implements semantic query expansion for retrieval-augmented generation + * flows. It uses a large language model to generate multiple semantically diverse + * variations of an input query to capture different perspectives and improve document + * retrieval coverage. + * + *

+ * Example usage:

{@code
+ * MultiQueryExpander expander = MultiQueryExpander.builder()
+ *    .chatClientBuilder(chatClientBuilder)
+ *    .numberOfQueries(3)
+ *    .build();
+ * List queries = expander.expand(new Query("How to run a Spring Boot app?"));
+ * }
+ * + * @author Thomas Vitale + * @since 1.0.0 + */ +public final class MultiQueryExpander implements QueryExpander { + + private static final Logger logger = LoggerFactory.getLogger(MultiQueryExpander.class); + + private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate(""" + You are an expert at information retrieval and search optimization. + Your task is to generate {number} different versions of the given query. + + Each variant must cover different perspectives or aspects of the topic, + while maintaining the core intent of the original query. The goal is to + expand the search space and improve the chances of finding relevant information. + + Do not explain your choices or add any other text. + Provide the query variants separated by newlines. + + Original query: {query} + + Query variants: + """); + + private static final Boolean DEFAULT_INCLUDE_ORIGINAL = false; + + private static final Integer DEFAULT_NUMBER_OF_QUERIES = 3; + + private final ChatClient chatClient; + + private final PromptTemplate promptTemplate; + + private final boolean includeOriginal; + + private final int numberOfQueries; + + public MultiQueryExpander(ChatClient.Builder chatClientBuilder, @Nullable PromptTemplate promptTemplate, + @Nullable Boolean includeOriginal, @Nullable Integer numberOfQueries) { + Assert.notNull(chatClientBuilder, "chatClientBuilder cannot be null"); + + this.chatClient = chatClientBuilder.build(); + this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE; + this.includeOriginal = includeOriginal != null ? includeOriginal : DEFAULT_INCLUDE_ORIGINAL; + this.numberOfQueries = numberOfQueries != null ? numberOfQueries : DEFAULT_NUMBER_OF_QUERIES; + + PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "number", "query"); + } + + @Override + public List expand(Query query) { + Assert.notNull(query, "query cannot be null"); + + logger.debug("Generating {} query variants", numberOfQueries); + + var response = chatClient.prompt() + .user(user -> user.text(promptTemplate.getTemplate()) + .param("number", numberOfQueries) + .param("query", query.text())) + .call() + .content(); + + if (response == null) { + logger.warn("Query expansion result is null. Returning the input query unchanged."); + return List.of(query); + } + + var queryVariants = Arrays.asList(response.split("\n")); + + if (CollectionUtils.isEmpty(queryVariants) || numberOfQueries != queryVariants.size()) { + logger.warn( + "Query expansion result does not contain the requested {} variants. Returning the input query unchanged.", + numberOfQueries); + return List.of(query); + } + + var queries = queryVariants.stream().filter(StringUtils::hasText).map(Query::new).collect(Collectors.toList()); + + if (includeOriginal) { + logger.debug("Including the original query in the result"); + queries.add(0, query); + } + + return queries; + } + + public static Builder builder() { + return new Builder(); + } + + public static final class Builder { + + private ChatClient.Builder chatClientBuilder; + + private PromptTemplate promptTemplate; + + private Boolean includeOriginal; + + private Integer numberOfQueries; + + private Builder() { + } + + public Builder chatClientBuilder(ChatClient.Builder chatClientBuilder) { + this.chatClientBuilder = chatClientBuilder; + return this; + } + + public Builder promptTemplate(PromptTemplate promptTemplate) { + this.promptTemplate = promptTemplate; + return this; + } + + public Builder includeOriginal(Boolean includeOriginal) { + this.includeOriginal = includeOriginal; + return this; + } + + public Builder numberOfQueries(Integer numberOfQueries) { + this.numberOfQueries = numberOfQueries; + return this; + } + + public MultiQueryExpander build() { + return new MultiQueryExpander(chatClientBuilder, promptTemplate, includeOriginal, numberOfQueries); + } + + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/analysis/query/expansion/QueryExpander.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/analysis/query/expansion/QueryExpander.java new file mode 100644 index 00000000000..5b16a7818f4 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/analysis/query/expansion/QueryExpander.java @@ -0,0 +1,53 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.rag.analysis.query.expansion; + +import org.springframework.ai.rag.Query; + +import java.util.List; +import java.util.function.Function; + +/** + * A component responsible for expanding the input query into a list of related queries + * based on a specified strategy. These expansions can be used to capture different + * perspectives or to break down complex queries into simpler, more manageable + * sub-queries, thereby improving the retrieval process. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public interface QueryExpander extends Function> { + + /** + * Expands the given query into a list of related queries according to the implemented + * strategy. + * @param query The original query to be expanded + * @return A list of expanded queries + */ + List expand(Query query); + + /** + * Expands the given query into a list of related queries according to the implemented + * strategy. + * @param query The original query to be expanded + * @return A list of expanded queries + */ + default List apply(Query query) { + return expand(query); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/source/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/analysis/query/expansion/package-info.java similarity index 79% rename from spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/source/package-info.java rename to spring-ai-core/src/main/java/org/springframework/ai/rag/analysis/query/expansion/package-info.java index 7d65ec54b55..b18934d7216 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/source/package-info.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/analysis/query/expansion/package-info.java @@ -15,15 +15,11 @@ */ /** - * RAG Sub-Module: Source. - *

- * This package provides the functional building blocks for retrieving documents from a - * data source. + * RAG Component: Query Expansion. */ - @NonNullApi @NonNullFields -package org.springframework.ai.rag.retrieval.source; +package org.springframework.ai.rag.analysis.query.expansion; import org.springframework.lang.NonNullApi; import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/analysis/query/transformation/QueryTransformer.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/analysis/query/transformation/QueryTransformer.java new file mode 100644 index 00000000000..efa97a6cf20 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/analysis/query/transformation/QueryTransformer.java @@ -0,0 +1,48 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.rag.analysis.query.transformation; + +import org.springframework.ai.rag.Query; + +import java.util.function.Function; + +/** + * Component responsible for transforming the input query based on a specified strategy. + * These transformations can be used to enhance the clarity, semantic meaning, or language + * of the query, thereby improving the effectiveness of the retrieval process. + * + * @author Thomas Vitale + * @since 1.0.0 + */ +public interface QueryTransformer extends Function { + + /** + * Transforms the given query according to the implemented strategy. + * @param query The original query to transform + * @return The transformed query + */ + Query transform(Query query); + + /** + * Transforms the given query according to the implemented strategy. + * @param query The original query to transform + * @return The transformed query + */ + default Query apply(Query query) { + return transform(query); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/analysis/query/transformation/TranslationQueryTransformer.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/analysis/query/transformation/TranslationQueryTransformer.java new file mode 100644 index 00000000000..591045e10ae --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/analysis/query/transformation/TranslationQueryTransformer.java @@ -0,0 +1,137 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.rag.analysis.query.transformation; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.prompt.ChatOptionsBuilder; +import org.springframework.ai.chat.prompt.PromptTemplate; +import org.springframework.ai.rag.Query; +import org.springframework.ai.util.PromptAssert; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.util.StringUtils; + +/** + * Transformer that handles translation of the input query to a target language using a + * large language model. It's aimed at optimizing similarity searches by translating a + * query into a language supported by the document store. + * + *

+ * Example usage:

{@code
+ * QueryTransformer transformer = TranslationQueryTransformer.builder()
+ *    .chatClientBuilder(chatClientBuilder)
+ *    .targetLanguage("english")
+ *    .build();
+ * Query transformedQuery = transformer.transform(new Query("Hvad er Danmarks hovedstad?"));
+ * }
+ * + * @author Thomas Vitale + * @since 1.0.0 + */ +public final class TranslationQueryTransformer implements QueryTransformer { + + private static final Logger logger = LoggerFactory.getLogger(TranslationQueryTransformer.class); + + private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate(""" + Given a user query, translate it to {targetLanguage}. + If the query is already in {targetLanguage}, return it unchanged. + If you don't know the language of the query, return it unchanged. + Do not add explanations nor any other text. + + Original query: {query} + + Translated query: + """); + + private final ChatClient chatClient; + + private final PromptTemplate promptTemplate; + + private final String targetLanguage; + + public TranslationQueryTransformer(ChatClient.Builder chatClientBuilder, @Nullable PromptTemplate promptTemplate, + String targetLanguage) { + Assert.notNull(chatClientBuilder, "chatClientBuilder cannot be null"); + Assert.hasText(targetLanguage, "targetLanguage cannot be null or empty"); + + this.chatClient = chatClientBuilder.build(); + this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE; + this.targetLanguage = targetLanguage; + + PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "targetLanguage", "query"); + } + + @Override + public Query transform(Query query) { + Assert.notNull(query, "query cannot be null"); + + logger.debug("Translating query to target language: {}", targetLanguage); + + var translatedQuery = chatClient.prompt() + .user(user -> user.text(promptTemplate.getTemplate()) + .param("targetLanguage", targetLanguage) + .param("query", query.text())) + .options(ChatOptionsBuilder.builder().withTemperature(0.0).build()) + .call() + .content(); + + if (!StringUtils.hasText(translatedQuery)) { + logger.warn("Query translation result is null/empty. Returning the input query unchanged."); + return query; + } + + return new Query(translatedQuery); + } + + public static Builder builder() { + return new Builder(); + } + + public static final class Builder { + + private ChatClient.Builder chatClientBuilder; + + private PromptTemplate promptTemplate; + + private String targetLanguage; + + private Builder() { + } + + public Builder chatClientBuilder(ChatClient.Builder chatClientBuilder) { + this.chatClientBuilder = chatClientBuilder; + return this; + } + + public Builder promptTemplate(PromptTemplate promptTemplate) { + this.promptTemplate = promptTemplate; + return this; + } + + public Builder targetLanguage(String targetLanguage) { + this.targetLanguage = targetLanguage; + return this; + } + + public TranslationQueryTransformer build() { + return new TranslationQueryTransformer(chatClientBuilder, promptTemplate, targetLanguage); + } + + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/analysis/query/transformation/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/analysis/query/transformation/package-info.java new file mode 100644 index 00000000000..a71c508bc74 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/analysis/query/transformation/package-info.java @@ -0,0 +1,25 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * RAG Component: Query Transformation. + */ +@NonNullApi +@NonNullFields +package org.springframework.ai.rag.analysis.query.transformation; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/augmentation/ContextualQueryAugmentor.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/augmentation/ContextualQueryAugmentor.java index b4c6bdf3f6b..b77ebd9e41b 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/rag/augmentation/ContextualQueryAugmentor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/augmentation/ContextualQueryAugmentor.java @@ -20,6 +20,8 @@ import java.util.Map; import java.util.stream.Collectors; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import org.springframework.ai.chat.prompt.PromptTemplate; import org.springframework.ai.document.Document; import org.springframework.ai.model.Content; @@ -29,14 +31,13 @@ import org.springframework.util.Assert; /** - * Augments the user query with contextual data. + * Augments the user query with contextual data from the content of the provided + * documents. * *

* Example usage:

{@code
  * QueryAugmentor augmentor = ContextualQueryAugmentor.builder()
- *    .promptTemplate(promptTemplate)
- *    .emptyContextPromptTemplate(emptyContextPromptTemplate)
- *    .allowEmptyContext(allowEmptyContext)
+ *    .allowEmptyContext(false)
  *    .build();
  * Query augmentedQuery = augmentor.augment(query, documents);
  * }
@@ -44,7 +45,9 @@ * @author Thomas Vitale * @since 1.0.0 */ -public class ContextualQueryAugmentor implements QueryAugmentor { +public final class ContextualQueryAugmentor implements QueryAugmentor { + + private static final Logger logger = LoggerFactory.getLogger(ContextualQueryAugmentor.class); private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate(""" Context information is below. @@ -92,6 +95,8 @@ public Query augment(Query query, List documents) { Assert.notNull(query, "query cannot be null"); Assert.notNull(documents, "documents cannot be null"); + logger.debug("Augmenting query with contextual data"); + if (documents.isEmpty()) { return augmentQueryWhenEmptyContext(query); } @@ -110,8 +115,10 @@ public Query augment(Query query, List documents) { private Query augmentQueryWhenEmptyContext(Query query) { if (this.allowEmptyContext) { + logger.debug("Empty context is allowed. Returning the original query."); return query; } + logger.debug("Empty context is not allowed. Returning a specific query for empty context."); return new Query(this.emptyContextPromptTemplate.render()); } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/augmentation/QueryAugmentor.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/augmentation/QueryAugmentor.java index 97b759e5c19..d7359f69dde 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/rag/augmentation/QueryAugmentor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/augmentation/QueryAugmentor.java @@ -23,7 +23,8 @@ import org.springframework.ai.rag.Query; /** - * Component for augmenting a query with contextual data based on a specific strategy. + * Component responsible for augmenting an input query with additional contextual data + * that can be used by a large language model to answer the query. * * @author Thomas Vitale * @since 1.0.0 diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/augmentation/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/augmentation/package-info.java index a82b662c6fc..ededce78e83 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/rag/augmentation/package-info.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/augmentation/package-info.java @@ -15,12 +15,13 @@ */ /** - * RAG Module: Augmentation. + * RAG Module: Query Augmentation. *

- * This package provides the functional building blocks for augmenting a user query with - * contextual data. + * This package encompasses all components involved in the augmentation phase of a + * retrieval augmented generation flow. The goal of this phase is to enrich the user query + * with additional context that can be used to improve the quality of the generated + * response. */ - @NonNullApi @NonNullFields package org.springframework.ai.rag.augmentation; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/package-info.java index a42ade9d8bb..b7061763599 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/rag/package-info.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/package-info.java @@ -16,20 +16,14 @@ /** * This package contains the core interfaces and classes supporting Retrieval Augmented - * Generation. + * Generation flows. *

- * It's based on the Modular RAG Architecture and provides the necessary building blocks - * to define and execute RAG flows. It includes three levels of abstraction: - *

    - *
  1. Module
  2. - *
  3. Sub-Module
  4. - *
  5. Operator
  6. - *
+ * It's inspired by the Modular RAG Architecture and provides the necessary building + * blocks to define and execute RAG flows. * * @see arXiv:2407.21059 * @see arXiv:2312.10997 */ - @NonNullApi @NonNullFields package org.springframework.ai.rag; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/package-info.java index 9995f15aa7c..87af7f55a7c 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/package-info.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/package-info.java @@ -15,11 +15,11 @@ */ /** - * RAG Module: Retrieval. + * RAG Module: Information Retrieval. *

- * This package includes submodules for handling the retrieval process in RAG flows. + * This package includes submodules for handling the retrieval process in + * retrieval-augmented generation flows. */ - @NonNullApi @NonNullFields package org.springframework.ai.rag.retrieval; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/source/DocumentRetriever.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/search/DocumentRetriever.java similarity index 63% rename from spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/source/DocumentRetriever.java rename to spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/search/DocumentRetriever.java index e5adc128168..7982dfe4c97 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/source/DocumentRetriever.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/search/DocumentRetriever.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.rag.retrieval.source; +package org.springframework.ai.rag.retrieval.search; import java.util.List; import java.util.function.Function; @@ -23,7 +23,8 @@ import org.springframework.ai.rag.Query; /** - * API for retrieving {@link Document}s from an underlying data source. + * Component responsible for retrieving {@link Document}s from an underlying data source, + * such as a search engine, a vector store, a database, or a knowledge graph. * * @author Christian Tzolov * @author Thomas Vitale @@ -32,22 +33,18 @@ public interface DocumentRetriever extends Function> { /** - * Retrieves {@link Document}s from an underlying data source using the given - * {@link Query}. + * Retrieves relevant documents from an underlying data source based on the given + * query. + * @param query The query to use for retrieving documents + * @return The list of relevant documents */ List retrieve(Query query); /** - * Retrieves {@link Document}s from an underlying data source using the given query - * string. - */ - default List retrieve(String query) { - return retrieve(new Query(query)); - } - - /** - * Retrieves {@link Document}s from an underlying data source using the given - * {@link Query}. + * Retrieves relevant documents from an underlying data source based on the given + * query. + * @param query The query to use for retrieving documents + * @return The list of relevant documents */ default List apply(Query query) { return retrieve(query); diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/source/VectorStoreDocumentRetriever.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/search/VectorStoreDocumentRetriever.java similarity index 92% rename from spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/source/VectorStoreDocumentRetriever.java rename to spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/search/VectorStoreDocumentRetriever.java index 1d2415767fc..3fd6aa074a4 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/source/VectorStoreDocumentRetriever.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/search/VectorStoreDocumentRetriever.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.rag.retrieval.source; +package org.springframework.ai.rag.retrieval.search; import java.util.List; import java.util.function.Supplier; @@ -28,7 +28,7 @@ import org.springframework.util.Assert; /** - * A document retriever that uses a vector store to search for documents. It supports + * Document retriever that uses a vector store to search for documents. It supports * filtering based on metadata, similarity threshold, and top-k results. * *

@@ -39,15 +39,13 @@ * .topK(5) * .filterExpression(filterExpression) * .build(); - * List documents = retriever.retrieve("example query"); + * List documents = retriever.retrieve(new Query("example query")); * } * * @author Thomas Vitale * @since 1.0.0 - * @see VectorStore - * @see Filter.Expression */ -public class VectorStoreDocumentRetriever implements DocumentRetriever { +public final class VectorStoreDocumentRetriever implements DocumentRetriever { private final VectorStore vectorStore; diff --git a/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/search/package-info.java b/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/search/package-info.java new file mode 100644 index 00000000000..961f18ddae4 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/rag/retrieval/search/package-info.java @@ -0,0 +1,25 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/** + * RAG Component: Document Search. + */ +@NonNullApi +@NonNullFields +package org.springframework.ai.rag.retrieval.search; + +import org.springframework.lang.NonNullApi; +import org.springframework.lang.NonNullFields; diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientBuilderTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientBuilderTests.java index d17998e706f..4bb321f85d9 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientBuilderTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/DefaultChatClientBuilderTests.java @@ -22,7 +22,9 @@ import org.springframework.ai.chat.model.ChatModel; import org.springframework.core.io.ClassPathResource; +import org.springframework.test.util.ReflectionTestUtils; +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.mockito.Mockito.mock; @@ -33,6 +35,21 @@ */ class DefaultChatClientBuilderTests { + @Test + void whenCloneBuilder() { + var chatModel = mock(ChatModel.class); + var originalBuilder = new DefaultChatClientBuilder(chatModel); + originalBuilder.defaultSystem("first instructions"); + var clonedBuilder = (DefaultChatClientBuilder) originalBuilder.clone(); + originalBuilder.defaultSystem("second instructions"); + + assertThat(clonedBuilder).isNotSameAs(originalBuilder); + var clonedBuilderRequestSpec = (DefaultChatClient.DefaultChatClientRequestSpec) ReflectionTestUtils + .getField(clonedBuilder, "defaultRequest"); + assertThat(clonedBuilderRequestSpec).isNotNull(); + assertThat(clonedBuilderRequestSpec.getSystemText()).isEqualTo("first instructions"); + } + @Test void whenChatModelIsNullThenThrows() { assertThatThrownBy(() -> new DefaultChatClientBuilder(null)).isInstanceOf(IllegalArgumentException.class) diff --git a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisorTests.java b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisorTests.java index 56c5ea81cc9..220bb0a9477 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisorTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisorTests.java @@ -29,7 +29,8 @@ import org.springframework.ai.chat.prompt.Prompt; import org.springframework.ai.document.Document; import org.springframework.ai.rag.Query; -import org.springframework.ai.rag.retrieval.source.DocumentRetriever; +import org.springframework.ai.rag.analysis.query.transformation.QueryTransformer; +import org.springframework.ai.rag.retrieval.search.DocumentRetriever; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -43,6 +44,33 @@ */ class RetrievalAugmentationAdvisorTests { + @Test + void whenQueryTransformerListIsNullThenThrow() { + assertThatThrownBy(() -> RetrievalAugmentationAdvisor.builder() + .queryTransformers((List) null) + .documentRetriever(mock(DocumentRetriever.class)) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("queryTransformers cannot be null"); + } + + @Test + void whenQueryTransformerArrayIsNullThenThrow() { + assertThatThrownBy(() -> RetrievalAugmentationAdvisor.builder() + .queryTransformers((QueryTransformer[]) null) + .documentRetriever(mock(DocumentRetriever.class)) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("queryTransformers cannot be null"); + } + + @Test + void whenQueryTransformersContainNullElementsThenThrow() { + assertThatThrownBy(() -> RetrievalAugmentationAdvisor.builder() + .queryTransformers(mock(QueryTransformer.class), null) + .documentRetriever(mock(DocumentRetriever.class)) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("queryTransformers cannot contain null elements"); + } + @Test void whenDocumentRetrieverIsNullThenThrow() { assertThatThrownBy(() -> RetrievalAugmentationAdvisor.builder().documentRetriever(null).build()) diff --git a/spring-ai-core/src/test/java/org/springframework/ai/rag/analysis/query/expansion/MultiQueryExpanderTests.java b/spring-ai-core/src/test/java/org/springframework/ai/rag/analysis/query/expansion/MultiQueryExpanderTests.java new file mode 100644 index 00000000000..e9a5739bc96 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/rag/analysis/query/expansion/MultiQueryExpanderTests.java @@ -0,0 +1,70 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.rag.analysis.query.expansion; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.prompt.PromptTemplate; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +/** + * Unit tests for {@link MultiQueryExpander}. + * + * @author Thomas Vitale + */ +class MultiQueryExpanderTests { + + @Test + void whenChatClientBuilderIsNullThenThrow() { + assertThatThrownBy(() -> MultiQueryExpander.builder().chatClientBuilder(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("chatClientBuilder cannot be null"); + } + + @Test + void whenQueryIsNullThenThrow() { + QueryExpander queryExpander = MultiQueryExpander.builder() + .chatClientBuilder(mock(ChatClient.Builder.class)) + .build(); + assertThatThrownBy(() -> queryExpander.expand(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("query cannot be null"); + } + + @Test + void whenPromptHasMissingNumberPlaceholderThenThrow() { + PromptTemplate customPromptTemplate = new PromptTemplate("You are the boss. Original query: {query}"); + assertThatThrownBy(() -> MultiQueryExpander.builder() + .chatClientBuilder(mock(ChatClient.Builder.class)) + .promptTemplate(customPromptTemplate) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("The following placeholders must be present in the prompt template") + .hasMessageContaining("number"); + } + + @Test + void whenPromptHasMissingQueryPlaceholderThenThrow() { + PromptTemplate customPromptTemplate = new PromptTemplate("You are the boss. Number of queries: {number}"); + assertThatThrownBy(() -> MultiQueryExpander.builder() + .chatClientBuilder(mock(ChatClient.Builder.class)) + .promptTemplate(customPromptTemplate) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("The following placeholders must be present in the prompt template") + .hasMessageContaining("query"); + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/rag/analysis/query/transformation/TranslationQueryTransformerTests.java b/spring-ai-core/src/test/java/org/springframework/ai/rag/analysis/query/transformation/TranslationQueryTransformerTests.java new file mode 100644 index 00000000000..df8c0272ce9 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/rag/analysis/query/transformation/TranslationQueryTransformerTests.java @@ -0,0 +1,73 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.rag.analysis.query.transformation; + +import org.junit.jupiter.api.Test; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.prompt.PromptTemplate; + +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.mockito.Mockito.mock; + +/** + * Unit tests for {@link TranslationQueryTransformer}. + * + * @author Thomas Vitale + */ +class TranslationQueryTransformerTests { + + @Test + void whenChatClientBuilderIsNullThenThrow() { + assertThatThrownBy(() -> TranslationQueryTransformer.builder().chatClientBuilder(null).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("chatClientBuilder cannot be null"); + } + + @Test + void whenQueryIsNullThenThrow() { + QueryTransformer queryTransformer = TranslationQueryTransformer.builder() + .chatClientBuilder(mock(ChatClient.Builder.class)) + .targetLanguage("italian") + .build(); + assertThatThrownBy(() -> queryTransformer.transform(null)).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("query cannot be null"); + } + + @Test + void whenPromptHasMissingTargetLanguagePlaceholderThenThrow() { + PromptTemplate customPromptTemplate = new PromptTemplate("Translate {query}"); + assertThatThrownBy(() -> TranslationQueryTransformer.builder() + .chatClientBuilder(mock(ChatClient.Builder.class)) + .targetLanguage("italian") + .promptTemplate(customPromptTemplate) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("The following placeholders must be present in the prompt template") + .hasMessageContaining("targetLanguage"); + } + + @Test + void whenPromptHasMissingQueryPlaceholderThenThrow() { + PromptTemplate customPromptTemplate = new PromptTemplate("Translate to {targetLanguage}"); + assertThatThrownBy(() -> TranslationQueryTransformer.builder() + .chatClientBuilder(mock(ChatClient.Builder.class)) + .targetLanguage("italian") + .promptTemplate(customPromptTemplate) + .build()).isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("The following placeholders must be present in the prompt template") + .hasMessageContaining("query"); + } + +} diff --git a/spring-ai-core/src/test/java/org/springframework/ai/rag/augmentation/ContextualQueryAugmentorTests.java b/spring-ai-core/src/test/java/org/springframework/ai/rag/augmentation/ContextualQueryAugmentorTests.java index 7d9cb112f31..60d3ac278c2 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/rag/augmentation/ContextualQueryAugmentorTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/rag/augmentation/ContextualQueryAugmentorTests.java @@ -35,6 +35,24 @@ */ class ContextualQueryAugmentorTests { + @Test + void whenPromptHasMissingContextPlaceholderThenThrow() { + PromptTemplate customPromptTemplate = new PromptTemplate("You are the boss. Query: {query}"); + assertThatThrownBy(() -> ContextualQueryAugmentor.builder().promptTemplate(customPromptTemplate).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("The following placeholders must be present in the prompt template") + .hasMessageContaining("context"); + } + + @Test + void whenPromptHasMissingQueryPlaceholderThenThrow() { + PromptTemplate customPromptTemplate = new PromptTemplate("You are the boss. Context: {context}"); + assertThatThrownBy(() -> ContextualQueryAugmentor.builder().promptTemplate(customPromptTemplate).build()) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("The following placeholders must be present in the prompt template") + .hasMessageContaining("query"); + } + @Test void whenQueryIsNullThenThrow() { QueryAugmentor augmenter = ContextualQueryAugmentor.builder().build(); diff --git a/spring-ai-core/src/test/java/org/springframework/ai/rag/retrieval/source/VectorStoreDocumentRetrieverTests.java b/spring-ai-core/src/test/java/org/springframework/ai/rag/retrieval/search/VectorStoreDocumentRetrieverTests.java similarity index 96% rename from spring-ai-core/src/test/java/org/springframework/ai/rag/retrieval/source/VectorStoreDocumentRetrieverTests.java rename to spring-ai-core/src/test/java/org/springframework/ai/rag/retrieval/search/VectorStoreDocumentRetrieverTests.java index 4d6771f6b95..645bd51cddf 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/rag/retrieval/source/VectorStoreDocumentRetrieverTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/rag/retrieval/search/VectorStoreDocumentRetrieverTests.java @@ -14,7 +14,7 @@ * limitations under the License. */ -package org.springframework.ai.rag.retrieval.source; +package org.springframework.ai.rag.retrieval.search; import java.util.List; import java.util.Map; @@ -41,6 +41,8 @@ /** * Unit tests for {@link VectorStoreDocumentRetriever}. + * + * @author Thomas Vitale */ class VectorStoreDocumentRetrieverTests { @@ -61,7 +63,7 @@ void searchRequestParameters() { .filterExpression(new Filter.Expression(EQ, new Filter.Key("location"), new Filter.Value("Rivendell"))) .build(); - documentRetriever.retrieve("query"); + documentRetriever.retrieve(new Query("query")); var searchRequestCaptor = ArgumentCaptor.forClass(SearchRequest.class); verify(mockVectorStore).similaritySearch(searchRequestCaptor.capture()); @@ -85,11 +87,11 @@ void dynamicFilterExpressions() { .build(); TenantContextHolder.setTenantIdentifier("tenant1"); - documentRetriever.retrieve("query"); + documentRetriever.retrieve(new Query("query")); TenantContextHolder.clear(); TenantContextHolder.setTenantIdentifier("tenant2"); - documentRetriever.retrieve("query"); + documentRetriever.retrieve(new Query("query")); TenantContextHolder.clear(); var searchRequestCaptor = ArgumentCaptor.forClass(SearchRequest.class); @@ -120,7 +122,7 @@ void defaultValuesAreAppliedWhenNotSpecified() { var mockVectorStore = mock(VectorStore.class); var documentRetriever = VectorStoreDocumentRetriever.builder().vectorStore(mockVectorStore).build(); - documentRetriever.retrieve("test query"); + documentRetriever.retrieve(new Query("test query")); var searchRequestCaptor = ArgumentCaptor.forClass(SearchRequest.class); verify(mockVectorStore).similaritySearch(searchRequestCaptor.capture()); diff --git a/spring-ai-integration-tests/pom.xml b/spring-ai-integration-tests/pom.xml new file mode 100644 index 00000000000..b2e38804b75 --- /dev/null +++ b/spring-ai-integration-tests/pom.xml @@ -0,0 +1,96 @@ + + + + + 4.0.0 + + org.springframework.ai + spring-ai + 1.0.0-SNAPSHOT + + spring-ai-integration-tests + jar + Spring AI Integration Tests + Integration tests for Spring AI + https://github.com/spring-projects/spring-ai + + + https://github.com/spring-projects/spring-ai + git://github.com/spring-projects/spring-ai.git + git@github.com:spring-projects/spring-ai.git + + + + 17 + 17 + true + true + + + + + org.springframework.boot + spring-boot-starter-web + test + + + + org.springframework.ai + spring-ai-openai-spring-boot-starter + ${project.parent.version} + test + + + + org.springframework.ai + spring-ai-pgvector-store-spring-boot-starter + ${project.parent.version} + test + + + + org.springframework.ai + spring-ai-markdown-document-reader + ${project.parent.version} + test + + + + org.springframework.boot + spring-boot-testcontainers + test + + + + org.testcontainers + postgresql + test + + + + org.springframework.boot + spring-boot-starter-test + + + com.vaadin.external.google + android-json + + + test + + + diff --git a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/TestApplication.java b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/TestApplication.java new file mode 100644 index 00000000000..5f43a0887cb --- /dev/null +++ b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/TestApplication.java @@ -0,0 +1,30 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.integration.tests; + +import org.springframework.boot.autoconfigure.SpringBootApplication; +import org.springframework.context.annotation.Import; + +/** + * Test application for integration tests. + * + * @author Thomas Vitale + */ +@SpringBootApplication +@Import(TestcontainersConfiguration.class) +public class TestApplication { + +} diff --git a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/TestcontainersConfiguration.java b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/TestcontainersConfiguration.java new file mode 100644 index 00000000000..b34aeb86b40 --- /dev/null +++ b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/TestcontainersConfiguration.java @@ -0,0 +1,37 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.integration.tests; + +import org.springframework.boot.test.context.TestConfiguration; +import org.springframework.boot.testcontainers.service.connection.ServiceConnection; +import org.springframework.context.annotation.Bean; +import org.testcontainers.containers.PostgreSQLContainer; + +/** + * Test configuration for Testcontainers-based Dev Services. + * + * @author Thomas Vitale + */ +@TestConfiguration(proxyBeanMethods = false) +class TestcontainersConfiguration { + + @Bean + @ServiceConnection + PostgreSQLContainer pgvectorContainer() { + return new PostgreSQLContainer<>("pgvector/pgvector:pg17"); + } + +} diff --git a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/RetrievalAugmentationAdvisorIT.java b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/RetrievalAugmentationAdvisorIT.java new file mode 100644 index 00000000000..a25fdec5ba9 --- /dev/null +++ b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/client/advisor/RetrievalAugmentationAdvisorIT.java @@ -0,0 +1,140 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.integration.tests.client.advisor; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.chat.client.advisor.RetrievalAugmentationAdvisor; +import org.springframework.ai.chat.model.ChatResponse; +import org.springframework.ai.document.Document; +import org.springframework.ai.document.DocumentReader; +import org.springframework.ai.evaluation.EvaluationRequest; +import org.springframework.ai.evaluation.EvaluationResponse; +import org.springframework.ai.evaluation.RelevancyEvaluator; +import org.springframework.ai.integration.tests.TestApplication; +import org.springframework.ai.openai.OpenAiChatModel; +import org.springframework.ai.rag.analysis.query.transformation.TranslationQueryTransformer; +import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever; +import org.springframework.ai.reader.markdown.MarkdownDocumentReader; +import org.springframework.ai.reader.markdown.config.MarkdownDocumentReaderConfig; +import org.springframework.ai.vectorstore.PgVectorStore; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.test.context.SpringBootTest; +import org.springframework.core.io.Resource; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link RetrievalAugmentationAdvisor}. + * + * @author Thomas Vitale + */ +@SpringBootTest(classes = TestApplication.class) +@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") +class RetrievalAugmentationAdvisorIT { + + private List knowledgeBaseDocuments; + + @Autowired + OpenAiChatModel openAiChatModel; + + @Autowired + PgVectorStore pgVectorStore; + + @Value("${classpath:documents/knowledge-base.md}") + Resource knowledgeBaseResource; + + @BeforeEach + void setUp() { + DocumentReader markdownReader = new MarkdownDocumentReader(knowledgeBaseResource, + MarkdownDocumentReaderConfig.defaultConfig()); + knowledgeBaseDocuments = markdownReader.read(); + pgVectorStore.add(knowledgeBaseDocuments); + } + + @AfterEach + void tearDown() { + pgVectorStore.delete(knowledgeBaseDocuments.stream().map(Document::getId).toList()); + } + + @Test + void ragBasic() { + String question = "Where does the adventure of Anacletus and Birba take place?"; + + RetrievalAugmentationAdvisor ragAdvisor = RetrievalAugmentationAdvisor.builder() + .documentRetriever(VectorStoreDocumentRetriever.builder().vectorStore(pgVectorStore).build()) + .build(); + + ChatResponse chatResponse = ChatClient.builder(openAiChatModel) + .build() + .prompt(question) + .advisors(ragAdvisor) + .call() + .chatResponse(); + + assertThat(chatResponse).isNotNull(); + + String response = chatResponse.getResult().getOutput().getContent(); + System.out.println(response); + assertThat(response).containsIgnoringCase("Highlands"); + + evaluateRelevancy(question, chatResponse); + } + + @Test + void ragWithTranslation() { + String question = "Hvor finder Anacletus og Birbas eventyr sted?"; + + RetrievalAugmentationAdvisor ragAdvisor = RetrievalAugmentationAdvisor.builder() + .queryTransformers(TranslationQueryTransformer.builder() + .chatClientBuilder(ChatClient.builder(openAiChatModel)) + .targetLanguage("english") + .build()) + .documentRetriever(VectorStoreDocumentRetriever.builder().vectorStore(pgVectorStore).build()) + .build(); + + ChatResponse chatResponse = ChatClient.builder(openAiChatModel) + .build() + .prompt(question) + .advisors(ragAdvisor) + .call() + .chatResponse(); + + assertThat(chatResponse).isNotNull(); + + String response = chatResponse.getResult().getOutput().getContent(); + System.out.println(response); + assertThat(response).containsIgnoringCase("Highlands"); + + evaluateRelevancy(question, chatResponse); + } + + private void evaluateRelevancy(String question, ChatResponse chatResponse) { + EvaluationRequest evaluationRequest = new EvaluationRequest(question, + chatResponse.getMetadata().get(RetrievalAugmentationAdvisor.DOCUMENT_CONTEXT), + chatResponse.getResult().getOutput().getContent()); + RelevancyEvaluator evaluator = new RelevancyEvaluator(ChatClient.builder(openAiChatModel)); + EvaluationResponse evaluationResponse = evaluator.evaluate(evaluationRequest); + assertThat(evaluationResponse.isPass()).isTrue(); + } + +} diff --git a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/rag/analysis/query/expansion/MultiQueryExpanderIT.java b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/rag/analysis/query/expansion/MultiQueryExpanderIT.java new file mode 100644 index 00000000000..3057f1c662e --- /dev/null +++ b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/rag/analysis/query/expansion/MultiQueryExpanderIT.java @@ -0,0 +1,90 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.integration.tests.rag.analysis.query.expansion; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.integration.tests.TestApplication; +import org.springframework.ai.openai.OpenAiChatModel; +import org.springframework.ai.rag.Query; +import org.springframework.ai.rag.analysis.query.expansion.MultiQueryExpander; +import org.springframework.ai.rag.analysis.query.expansion.QueryExpander; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link MultiQueryExpander}. + * + * @author Thomas Vitale + */ +@SpringBootTest(classes = TestApplication.class) +@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") +class MultiQueryExpanderIT { + + @Autowired + OpenAiChatModel openAiChatModel; + + @Test + void whenExpanderWithDefaults() { + Query query = new Query("What is the weather in Rome?"); + QueryExpander queryExpander = MultiQueryExpander.builder() + .chatClientBuilder(ChatClient.builder(openAiChatModel)) + .build(); + + List queries = queryExpander.apply(query); + + assertThat(queries).isNotNull(); + queries.forEach(System.out::println); + assertThat(queries).hasSize(3); + } + + @Test + void whenExpanderWithCustomQueryNumber() { + Query query = new Query("What is the weather in Rome?"); + QueryExpander queryExpander = MultiQueryExpander.builder() + .chatClientBuilder(ChatClient.builder(openAiChatModel)) + .numberOfQueries(4) + .build(); + + List queries = queryExpander.apply(query); + + assertThat(queries).isNotNull(); + queries.forEach(System.out::println); + assertThat(queries).hasSize(4); + } + + @Test + void whenExpanderWithOriginalQueryIncluded() { + Query query = new Query("What is the weather in Rome?"); + QueryExpander queryExpander = MultiQueryExpander.builder() + .chatClientBuilder(ChatClient.builder(openAiChatModel)) + .numberOfQueries(3) + .includeOriginal(true) + .build(); + + List queries = queryExpander.apply(query); + + assertThat(queries).isNotNull(); + queries.forEach(System.out::println); + assertThat(queries).hasSize(4); + } + +} diff --git a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/rag/analysis/query/transformation/TranslationQueryTransformerIT.java b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/rag/analysis/query/transformation/TranslationQueryTransformerIT.java new file mode 100644 index 00000000000..106a036820e --- /dev/null +++ b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/rag/analysis/query/transformation/TranslationQueryTransformerIT.java @@ -0,0 +1,58 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.integration.tests.rag.analysis.query.transformation; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.chat.client.ChatClient; +import org.springframework.ai.integration.tests.TestApplication; +import org.springframework.ai.openai.OpenAiChatModel; +import org.springframework.ai.rag.Query; +import org.springframework.ai.rag.analysis.query.transformation.QueryTransformer; +import org.springframework.ai.rag.analysis.query.transformation.TranslationQueryTransformer; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link TranslationQueryTransformer}. + * + * @author Thomas Vitale + */ +@SpringBootTest(classes = TestApplication.class) +@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") +class TranslationQueryTransformerIT { + + @Autowired + OpenAiChatModel openAiChatModel; + + @Test + void whenTransformerWithDefaults() { + Query query = new Query("Hvad er Danmarks hovedstad?"); + QueryTransformer queryTransformer = TranslationQueryTransformer.builder() + .chatClientBuilder(ChatClient.builder(openAiChatModel)) + .targetLanguage("english") + .build(); + + Query transformedQuery = queryTransformer.apply(query); + + assertThat(transformedQuery).isNotNull(); + System.out.println(transformedQuery); + assertThat(transformedQuery.text()).containsIgnoringCase("Denmark").containsIgnoringCase("capital"); + } + +} diff --git a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/rag/augmentation/ContextualQueryAugmentorIT.java b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/rag/augmentation/ContextualQueryAugmentorIT.java new file mode 100644 index 00000000000..26bccac5727 --- /dev/null +++ b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/rag/augmentation/ContextualQueryAugmentorIT.java @@ -0,0 +1,89 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.integration.tests.rag.augmentation; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.document.Document; +import org.springframework.ai.integration.tests.TestApplication; +import org.springframework.ai.openai.OpenAiChatModel; +import org.springframework.ai.rag.Query; +import org.springframework.ai.rag.augmentation.ContextualQueryAugmentor; +import org.springframework.ai.rag.augmentation.QueryAugmentor; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; + +import java.util.List; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * Integration tests for {@link ContextualQueryAugmentor}. + * + * @author Thomas Vitale + */ +@SpringBootTest(classes = TestApplication.class) +@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") +class ContextualQueryAugmentorIT { + + @Autowired + OpenAiChatModel openAiChatModel; + + @Test + void whenContextIsProvided() { + QueryAugmentor queryAugmentor = ContextualQueryAugmentor.builder().build(); + Query query = new Query("What is Iorek's dream?"); + List documents = List + .of(new Document("Iorek was a little polar bear who lived in the Arctic circle."), new Document( + "Iorek loved to explore the snowy landscape and dreamt of one day going on an adventure around the North Pole.")); + + Query augmentedQuery = queryAugmentor.augment(query, documents); + String response = openAiChatModel.call(augmentedQuery.text()); + + assertThat(response).isNotEmpty(); + System.out.println(response); + assertThat(response).containsIgnoringCase("North Pole"); + assertThat(response).doesNotContainIgnoringCase("context"); + assertThat(response).doesNotContainIgnoringCase("information"); + } + + @Test + void whenAllowEmptyContext() { + QueryAugmentor queryAugmentor = ContextualQueryAugmentor.builder().build(); + Query query = new Query("What is Iorek's dream?"); + List documents = List.of(); + Query augmentedQuery = queryAugmentor.augment(query, documents); + String response = openAiChatModel.call(augmentedQuery.text()); + + assertThat(response).isNotEmpty(); + System.out.println(response); + assertThat(response).containsIgnoringCase("Iorek"); + } + + @Test + void whenNotAllowEmptyContext() { + QueryAugmentor queryAugmentor = ContextualQueryAugmentor.builder().allowEmptyContext(false).build(); + Query query = new Query("What is Iorek's dream?"); + List documents = List.of(); + Query augmentedQuery = queryAugmentor.augment(query, documents); + String response = openAiChatModel.call(augmentedQuery.text()); + + assertThat(response).isNotEmpty(); + System.out.println(response); + assertThat(response).doesNotContainIgnoringCase("Iorek"); + } + +} diff --git a/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/rag/retrieval/search/VectorStoreDocumentRetrieverIT.java b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/rag/retrieval/search/VectorStoreDocumentRetrieverIT.java new file mode 100644 index 00000000000..65b8dec3621 --- /dev/null +++ b/spring-ai-integration-tests/src/test/java/org/springframework/ai/integration/tests/rag/retrieval/search/VectorStoreDocumentRetrieverIT.java @@ -0,0 +1,111 @@ +/* + * Copyright 2023-2024 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.ai.integration.tests.rag.retrieval.search; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; +import org.springframework.ai.document.Document; +import org.springframework.ai.integration.tests.TestApplication; +import org.springframework.ai.rag.Query; +import org.springframework.ai.rag.retrieval.search.DocumentRetriever; +import org.springframework.ai.rag.retrieval.search.VectorStoreDocumentRetriever; +import org.springframework.ai.vectorstore.PgVectorStore; +import org.springframework.ai.vectorstore.filter.Filter; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.context.SpringBootTest; + +import java.util.List; +import java.util.Map; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.EQ; + +/** + * Integration tests for {@link VectorStoreDocumentRetriever}. + * + * @author Thomas Vitale + */ +@SpringBootTest(classes = TestApplication.class) +@EnabledIfEnvironmentVariable(named = "OPENAI_API_KEY", matches = ".*") +class VectorStoreDocumentRetrieverIT { + + private static final Map documents = Map.of("1", new Document( + "Anacletus was a majestic snowy owl with unusually bright golden eyes and distinctive black speckles across his wings.", + Map.of("location", "Whispering Woods")), "2", + new Document( + "Anacletus made his home in an ancient hollow oak tree deep within the Whispering Woods, where local villagers often heard his haunting calls at midnight.", + Map.of("location", "Whispering Woods")), + "3", + new Document( + "Despite being a nocturnal hunter like other owls, Anacletus had developed a peculiar habit of collecting shiny objects, especially lost coins and jewelry that glinted in the moonlight.", + Map.of()), + "4", + new Document( + "Birba was a plump Siamese cat with mismatched eyes - one blue and one green - who spent her days lounging on velvet cushions and judging everyone with a perpetual look of disdain.", + Map.of("location", "Alfea"))); + + @Autowired + PgVectorStore pgVectorStore; + + @BeforeEach + void setUp() { + pgVectorStore.add(List.copyOf(documents.values())); + } + + @AfterEach + void tearDown() { + pgVectorStore.delete(documents.values().stream().map(Document::getId).toList()); + } + + @Test + void withFilter() { + DocumentRetriever documentRetriever = VectorStoreDocumentRetriever.builder() + .vectorStore(pgVectorStore) + .similarityThreshold(0.50) + .topK(3) + .filterExpression( + new Filter.Expression(EQ, new Filter.Key("location"), new Filter.Value("Whispering Woods"))) + .build(); + + List retrievedDocuments = documentRetriever.retrieve(new Query("Who is Anacletus?")); + + assertThat(retrievedDocuments).hasSize(2); + assertThat(retrievedDocuments).anyMatch(document -> document.getId().equals(documents.get("1").getId())); + assertThat(retrievedDocuments).anyMatch(document -> document.getId().equals(documents.get("2").getId())); + + retrievedDocuments = documentRetriever.retrieve(new Query("Who is Birba?")); + assertThat(retrievedDocuments).noneMatch(document -> document.getId().equals(documents.get("4").getId())); + } + + @Test + void withNoFilter() { + DocumentRetriever documentRetriever = VectorStoreDocumentRetriever.builder() + .vectorStore(pgVectorStore) + .similarityThreshold(0.50) + .topK(3) + .build(); + + List retrievedDocuments = documentRetriever.retrieve(new Query("Who is Anacletus?")); + + assertThat(retrievedDocuments).hasSize(3); + assertThat(retrievedDocuments).anyMatch(document -> document.getId().equals(documents.get("1").getId())); + assertThat(retrievedDocuments).anyMatch(document -> document.getId().equals(documents.get("2").getId())); + assertThat(retrievedDocuments).anyMatch(document -> document.getId().equals(documents.get("3").getId())); + } + +} diff --git a/spring-ai-integration-tests/src/test/resources/application.yml b/spring-ai-integration-tests/src/test/resources/application.yml new file mode 100644 index 00000000000..b6aeb131b95 --- /dev/null +++ b/spring-ai-integration-tests/src/test/resources/application.yml @@ -0,0 +1,16 @@ +spring: + main: + web-application-type: none + ai: + openai: + chat: + options: + model: gpt-4o-mini + embedding: + options: + model: text-embedding-ada-002 + retry: + max-attempts: 3 + vectorstore: + pgvector: + initialize-schema: true \ No newline at end of file diff --git a/spring-ai-integration-tests/src/test/resources/documents/knowledge-base.md b/spring-ai-integration-tests/src/test/resources/documents/knowledge-base.md new file mode 100644 index 00000000000..8553110921e --- /dev/null +++ b/spring-ai-integration-tests/src/test/resources/documents/knowledge-base.md @@ -0,0 +1,41 @@ +# Anacletus and Birba's Quest for the Loch of the Stars + +## Chapter 1: The Map and the Adventure + +Once upon a time, in a cozy little cottage nestled at the edge of the Scottish Highlands, lived an owl named Anacletus and a curious cat named Birba. Anacletus was wise and careful, always reading maps and planning things thoroughly, while Birba was lively and adventurous, always ready to chase after the next interesting thing. Despite their differences, they were the best of friends and loved going on little adventures together. + +## Chapter 2: The Journey Begins + +One sunny morning, Anacletus showed Birba an old, crinkled map he’d found in the attic. “Look, Birba,” he said, pointing with his feathery wing. “This map leads to the legendary Loch of the Stars. They say it shines brighter than any other lake at night.” Birba’s eyes sparkled with excitement. “Oh, we have to go there!” she meowed. So, they packed a small bag with snacks, a compass, and a flashlight, and off they went, eager to find the legendary loch. + +## Chapter 3: The Highland Adventure + +Their journey began with a climb up the rolling hills covered in purple heather. Anacletus flapped his wings, soaring ahead to scout for any obstacles, while Birba trotted along below, her nose sniffing the air for interesting scents. Soon, they came across a bubbling brook. Anacletus carefully flew over it, but Birba hesitated. “Just a little jump!” Anacletus called out. With a deep breath, Birba leaped and landed safely on the other side. She purred proudly, and they continued on their way. + +## Chapter 4: The Highland Cows and the Hidden Path + +As they ventured deeper into the Highlands, they stumbled upon a herd of curious Highland cows with long, shaggy hair. The cows mooed softly, and one of them named Fergus approached. “Where are you two headed?” Fergus asked. “We’re searching for the Loch of the Stars!” Anacletus replied. Fergus nodded knowingly and pointed his nose north. “Follow the path by the big stones, and it will lead you closer to the loch,” he said. Thanking Fergus, they set off again, Birba occasionally stopping to bat at the fluttering butterflies along the way. + +## Chapter 5: The Mysterious Forest and the Deer Family + +The day wore on, and they soon found themselves in a mysterious forest. Tall, ancient pine trees surrounded them, casting long shadows. “Stay close, Birba,” Anacletus whispered, his wise eyes scanning for any sign of danger. But Birba had already darted after a flicker of light, thinking it was a firefly. Anacletus sighed and followed her until they came to a hidden glade where a family of deer grazed quietly. The smallest fawn looked up and gave them a curious nod before they moved along. + +## Chapter 6: The Loch of the Stars + +After a while, the sun began to set, painting the sky in shades of pink and gold. Anacletus decided it was a good time to rest. They found a cozy hollow at the base of a tree, where they shared the snacks they’d packed. Birba munched on her fish treats while Anacletus nibbled on a biscuit. “Do you think we’ll find the Loch of the Stars?” Birba asked, her eyes twinkling. “I think so,” Anacletus replied with a wise smile. “We’re getting closer.” + +## Chapter 7: The Shimmering Loch + +As night fell, they finally reached the top of a hill where they could see a shimmering light in the distance. “Look, Birba!” Anacletus hooted excitedly. There, nestled among the hills, was the Loch of the Stars, gleaming like a sky full of stars. The two friends hurried down to the water’s edge, marveling at how the loch sparkled under the moonlight, casting a gentle glow all around. + +## Chapter 8: The Magic of the Loch + +Birba dipped a curious paw into the water, causing ripples that sent stars dancing across the surface. “It’s beautiful!” she gasped. Anacletus nodded, his heart filled with awe. They spent the night by the loch, watching the shimmering stars reflected in the water, feeling as though they were surrounded by magic. + +## Chapter 9: The Journey Home + +When dawn broke, the shimmering loch returned to its quiet, glassy calm. With a satisfied yawn, Birba stretched and said, “That was the best adventure yet.” Anacletus agreed, feeling a warmth in his feathers as they turned back toward home, carrying memories of the Loch of the Stars in their hearts. + +## Chapter 10: The End of the Adventure + +And as they made their way back to their cozy cottage, they already started dreaming of their next big adventure—because Anacletus and Birba knew that the Scottish Highlands held endless wonders for those who dared to explore. From f9022294f0aafa3212a4399654ba71177188d4a5 Mon Sep 17 00:00:00 2001 From: Thomas Vitale Date: Wed, 13 Nov 2024 20:24:06 +0100 Subject: [PATCH 2/2] Modular RAG - Query Analysis MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Query Analysis * Introduce Query Analysis Module * Define QueryTransformer API and TranslationQueryTransformer implementation * Define QueryExpander API and MultiQueryExpander implementation * Support QueryTransformer in RetrievalAugmentationAdvisor (support for QueryExpander will be in the next PR together with the needed DocumentFuser API). Improvements * Refine Retrieval and Augmentation Modules for increased robustness * Expand test coverage for both modules * Define clone() method for ChatClient.Builder Tests * Introduce “spring-ai-integration-tests” for full-fledged integration tests * Add integration tests for RAG modules * Add integration tests for RAG advisor Relates to #gh-1603 Signed-off-by: Thomas Vitale --- .../ai/chat/client/advisor/RetrievalAugmentationAdvisor.java | 2 +- spring-ai-integration-tests/src/test/resources/application.yml | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java index 4474e87a40a..5d6a9676ddc 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/chat/client/advisor/RetrievalAugmentationAdvisor.java @@ -80,7 +80,7 @@ public RetrievalAugmentationAdvisor(List queryTransformers, Do this.queryTransformers = queryTransformers; this.documentRetriever = documentRetriever; this.queryAugmentor = queryAugmentor != null ? queryAugmentor : ContextualQueryAugmentor.builder().build(); - this.protectFromBlocking = protectFromBlocking != null ? protectFromBlocking : false; + this.protectFromBlocking = protectFromBlocking != null ? protectFromBlocking : true; this.order = order != null ? order : 0; } diff --git a/spring-ai-integration-tests/src/test/resources/application.yml b/spring-ai-integration-tests/src/test/resources/application.yml index b6aeb131b95..278a7b2c356 100644 --- a/spring-ai-integration-tests/src/test/resources/application.yml +++ b/spring-ai-integration-tests/src/test/resources/application.yml @@ -3,6 +3,7 @@ spring: web-application-type: none ai: openai: + api-key: ${OPENAI_API_KEY} chat: options: model: gpt-4o-mini