Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,8 @@
<module>spring-ai-spring-boot-starters/spring-ai-starter-watsonx-ai</module>
<module>spring-ai-spring-boot-starters/spring-ai-starter-zhipuai</module>
<module>spring-ai-spring-boot-starters/spring-ai-starter-moonshot</module>

<module>spring-ai-integration-tests</module>
</modules>

<organization>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,8 @@ <I, O> Builder defaultFunction(String name, String description,

Builder defaultToolContext(Map<String, Object> toolContext);

Builder clone();

ChatClient build();

}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ public ChatClient build() {
return new DefaultChatClient(this.defaultRequest);
}

public Builder clone() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All the other APIs used by ChatClient have similar methods, but they are called mutate() instead of clone(). I wonder if it'd make sense to change them to clone()? It would be aligned with the naming used in other Spring APIs (like RestClient) and also aligned with Java conventions in general.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed, we figured out the mystery :) The Builder has the "clone()" method and the Client APIs have the "mutate()" method.

return this.defaultRequest.mutate();
}

public Builder defaultAdvisors(Advisor... advisors) {
this.defaultRequest.advisors(advisors);
return this;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@

package org.springframework.ai.chat.client.advisor;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Predicate;

import org.springframework.ai.rag.analysis.query.transformation.QueryTransformer;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
Expand All @@ -37,13 +40,13 @@
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.augmentation.ContextualQueryAugmentor;
import org.springframework.ai.rag.augmentation.QueryAugmentor;
import org.springframework.ai.rag.retrieval.source.DocumentRetriever;
import org.springframework.ai.rag.retrieval.search.DocumentRetriever;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;

/**
* This advisor implements common Retrieval Augmented Generation (RAG) flows using the
* Advisor that implements common Retrieval Augmented Generation (RAG) flows using the
* building blocks defined in the {@link org.springframework.ai.rag} package and following
* the Modular RAG Architecture.
* <p>
Expand All @@ -55,10 +58,12 @@
* @see <a href="http://export.arxiv.org/abs/2407.21059">arXiv:2407.21059</a>
* @see <a href="https://export.arxiv.org/abs/2312.10997">arXiv:2312.10997</a>
*/
public class RetrievalAugmentationAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {
public final class RetrievalAugmentationAdvisor implements CallAroundAdvisor, StreamAroundAdvisor {

public static final String DOCUMENT_CONTEXT = "rag_document_context";

private final List<QueryTransformer> queryTransformers;

private final DocumentRetriever documentRetriever;

private final QueryAugmentor queryAugmentor;
Expand All @@ -67,12 +72,15 @@ public class RetrievalAugmentationAdvisor implements CallAroundAdvisor, StreamAr

private final int order;

public RetrievalAugmentationAdvisor(DocumentRetriever documentRetriever, @Nullable QueryAugmentor queryAugmentor,
@Nullable Boolean protectFromBlocking, @Nullable Integer order) {
public RetrievalAugmentationAdvisor(List<QueryTransformer> queryTransformers, DocumentRetriever documentRetriever,
@Nullable QueryAugmentor queryAugmentor, @Nullable Boolean protectFromBlocking, @Nullable Integer order) {
Assert.notNull(queryTransformers, "queryTransformers cannot be null");
Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements");
Assert.notNull(documentRetriever, "documentRetriever cannot be null");
this.queryTransformers = queryTransformers;
this.documentRetriever = documentRetriever;
this.queryAugmentor = queryAugmentor != null ? queryAugmentor : ContextualQueryAugmentor.builder().build();
this.protectFromBlocking = protectFromBlocking != null ? protectFromBlocking : false;
this.protectFromBlocking = protectFromBlocking != null ? protectFromBlocking : true;
this.order = order != null ? order : 0;
}

Expand Down Expand Up @@ -119,30 +127,45 @@ private AdvisedRequest before(AdvisedRequest request) {
Map<String, Object> context = new HashMap<>(request.adviseContext());

// 0. Create a query from the user text and parameters.
Query query = new Query(new PromptTemplate(request.userText(), request.userParams()).render());
Query originalQuery = new Query(new PromptTemplate(request.userText(), request.userParams()).render());

// 1. Transform original user query based on a chain of query transformers.
Query transformedQuery = originalQuery;
for (var queryTransformer : queryTransformers) {
transformedQuery = queryTransformer.apply(transformedQuery);
}

// 1. Retrieve similar documents for the original query.
List<Document> documents = this.documentRetriever.retrieve(query);
// 2. Retrieve similar documents for the transformed query.
List<Document> documents = this.documentRetriever.retrieve(transformedQuery);
context.put(DOCUMENT_CONTEXT, documents);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe a RAGContext that in turn contains the documents so that for the future when additional things are added to the context for RAG, it is all in a central spot?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We agreed to talk more about this when we reconsider the design of ChatResponse and how to propagate back evidence from different intermediate steps in a ChatClient, such as retrieved documents and called functions.


// 2. Augment user query with the contextual data.
Query augmentedQuery = this.queryAugmentor.augment(query, documents);
// 3. Augment user query with the document contextual data.
Query augmentedQuery = this.queryAugmentor.augment(transformedQuery, documents);

return AdvisedRequest.from(request).withUserText(augmentedQuery.text()).withAdviseContext(context).build();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wondering about the multimodal/media aspects of this, as it is focused on text and any media for the request is lost.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The AdvisedRequest has a separate Media property for the media elements, so they are currently propagated correctly.

}

private AdvisedResponse after(AdvisedResponse advisedResponse) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could make protected, along with before if someone wants to customize the algorithm.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed, the customisation approach will be based on implementing a custom Advisor and using the available RAG components to build a custom RAG flow, whereas all these out-of-the-box implementations will be "final" following the Spring Security approach.

ChatResponse.Builder chatResponseBuilder = ChatResponse.builder().from(advisedResponse.response());
ChatResponse.Builder chatResponseBuilder;
if (advisedResponse.response() == null) {
chatResponseBuilder = ChatResponse.builder();
}
else {
chatResponseBuilder = ChatResponse.builder().from(advisedResponse.response());
}
chatResponseBuilder.withMetadata(DOCUMENT_CONTEXT, advisedResponse.adviseContext().get(DOCUMENT_CONTEXT));
return new AdvisedResponse(chatResponseBuilder.build(), advisedResponse.adviseContext());
}

private Predicate<AdvisedResponse> onFinishReason() {
return advisedResponse -> advisedResponse.response()
.getResults()
.stream()
.anyMatch(result -> result != null && result.getMetadata() != null
&& StringUtils.hasText(result.getMetadata().getFinishReason()));
return advisedResponse -> {
ChatResponse chatResponse = advisedResponse.response();
return chatResponse != null && chatResponse.getResults() != null
&& chatResponse.getResults()
.stream()
.anyMatch(result -> result != null && result.getMetadata() != null
&& StringUtils.hasText(result.getMetadata().getFinishReason()));
};
}

@Override
Expand All @@ -157,6 +180,8 @@ public int getOrder() {

public static final class Builder {

private final List<QueryTransformer> queryTransformers = new ArrayList<>();

private DocumentRetriever documentRetriever;

private QueryAugmentor queryAugmentor;
Expand All @@ -168,6 +193,18 @@ public static final class Builder {
private Builder() {
}

/**
 * Appends the given query transformers to the chain configured on this builder.
 * Transformers are kept in insertion order; repeated calls accumulate rather than
 * replace.
 * <p>
 * Null elements are rejected here instead of being stored and only detected when
 * {@link #build()} hands the list to the advisor constructor, so the failure points
 * at the offending call and the builder is left unchanged.
 * @param queryTransformers the transformers to append; must not be {@code null} nor
 * contain {@code null} elements
 * @return this builder
 * @throws IllegalArgumentException if the list is {@code null} or contains a
 * {@code null} element
 */
public Builder queryTransformers(List<QueryTransformer> queryTransformers) {
	Assert.notNull(queryTransformers, "queryTransformers cannot be null");
	// Fail fast: validate before mutating the builder state.
	Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements");
	this.queryTransformers.addAll(queryTransformers);
	return this;
}

/**
 * Varargs variant that appends the given query transformers to the chain configured
 * on this builder. Transformers are kept in insertion order; repeated calls
 * accumulate rather than replace.
 * <p>
 * Null elements are rejected here instead of being stored and only detected when
 * {@link #build()} hands the list to the advisor constructor, so the failure points
 * at the offending call and the builder is left unchanged.
 * @param queryTransformers the transformers to append; must not be {@code null} nor
 * contain {@code null} elements
 * @return this builder
 * @throws IllegalArgumentException if the array is {@code null} or contains a
 * {@code null} element
 */
public Builder queryTransformers(QueryTransformer... queryTransformers) {
	Assert.notNull(queryTransformers, "queryTransformers cannot be null");
	// Fail fast: validate before mutating the builder state.
	Assert.noNullElements(queryTransformers, "queryTransformers cannot contain null elements");
	this.queryTransformers.addAll(Arrays.asList(queryTransformers));
	return this;
}

public Builder documentRetriever(DocumentRetriever documentRetriever) {
this.documentRetriever = documentRetriever;
return this;
Expand All @@ -189,7 +226,7 @@ public Builder order(Integer order) {
}

public RetrievalAugmentationAdvisor build() {
return new RetrievalAugmentationAdvisor(this.documentRetriever, this.queryAugmentor,
return new RetrievalAugmentationAdvisor(this.queryTransformers, this.documentRetriever, this.queryAugmentor,
this.protectFromBlocking, this.order);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/**
* RAG Module: Query Analysis.
* <p>
* This package encompasses all components involved in the pre-retrieval phase of a
* retrieval augmented generation flow. Queries are transformed, expanded, or constructed
* so as to enhance the effectiveness and accuracy of the subsequent retrieval phase.
*/
@NonNullApi
@NonNullFields
package org.springframework.ai.rag.analysis;

import org.springframework.lang.NonNullApi;
import org.springframework.lang.NonNullFields;
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.springframework.ai.rag.analysis.query.expansion;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.prompt.PromptTemplate;
import org.springframework.ai.rag.Query;
import org.springframework.ai.util.PromptAssert;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.StringUtils;

import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

/**
* Expander that implements semantic query expansion for retrieval-augmented generation
* flows. It uses a large language model to generate multiple semantically diverse
* variations of an input query to capture different perspectives and improve document
* retrieval coverage.
*
* <p>
* Example usage: <pre>{@code
* MultiQueryExpander expander = MultiQueryExpander.builder()
* .chatClientBuilder(chatClientBuilder)
* .numberOfQueries(3)
* .build();
* List<Query> queries = expander.expand(new Query("How to run a Spring Boot app?"));
* }</pre>
*
* @author Thomas Vitale
* @since 1.0.0
*/
public final class MultiQueryExpander implements QueryExpander {

	private static final Logger logger = LoggerFactory.getLogger(MultiQueryExpander.class);

	/**
	 * Prompt used when the caller does not supply one. It exposes the {@code number}
	 * and {@code query} placeholders and asks the model to return one variant per line.
	 */
	private static final PromptTemplate DEFAULT_PROMPT_TEMPLATE = new PromptTemplate("""
			You are an expert at information retrieval and search optimization.
			Your task is to generate {number} different versions of the given query.

			Each variant must cover different perspectives or aspects of the topic,
			while maintaining the core intent of the original query. The goal is to
			expand the search space and improve the chances of finding relevant information.

			Do not explain your choices or add any other text.
			Provide the query variants separated by newlines.

			Original query: {query}

			Query variants:
			""");

	private static final Boolean DEFAULT_INCLUDE_ORIGINAL = false;

	private static final Integer DEFAULT_NUMBER_OF_QUERIES = 3;

	private final ChatClient chatClient;

	private final PromptTemplate promptTemplate;

	private final boolean includeOriginal;

	private final int numberOfQueries;

	/**
	 * Creates an expander backed by a {@link ChatClient} built from the given builder.
	 * @param chatClientBuilder builder used to create the chat client; must not be
	 * {@code null}
	 * @param promptTemplate template for the expansion prompt, or {@code null} to use
	 * the default; it is checked for the {@code number} and {@code query} placeholders
	 * @param includeOriginal whether the input query is prepended to the result, or
	 * {@code null} for the default ({@code false})
	 * @param numberOfQueries how many variants to request, or {@code null} for the
	 * default ({@code 3})
	 */
	public MultiQueryExpander(ChatClient.Builder chatClientBuilder, @Nullable PromptTemplate promptTemplate,
			@Nullable Boolean includeOriginal, @Nullable Integer numberOfQueries) {
		Assert.notNull(chatClientBuilder, "chatClientBuilder cannot be null");

		this.chatClient = chatClientBuilder.build();
		this.promptTemplate = promptTemplate != null ? promptTemplate : DEFAULT_PROMPT_TEMPLATE;
		this.includeOriginal = includeOriginal != null ? includeOriginal : DEFAULT_INCLUDE_ORIGINAL;
		this.numberOfQueries = numberOfQueries != null ? numberOfQueries : DEFAULT_NUMBER_OF_QUERIES;

		PromptAssert.templateHasRequiredPlaceholders(this.promptTemplate, "number", "query");
	}

	/**
	 * Asks the model for {@code numberOfQueries} variants of the given query and
	 * returns them as separate queries.
	 * <p>
	 * The model response is split on line breaks; each line is stripped of surrounding
	 * whitespace and blank lines are discarded <em>before</em> the variant count is
	 * validated. If the response is {@code null} or does not yield exactly the
	 * requested number of non-blank variants, a warning is logged and a single-element
	 * list holding the input query is returned.
	 * @param query the query to expand; must not be {@code null}
	 * @return the generated variants, preceded by the input query when
	 * {@code includeOriginal} is enabled, or just the input query on fallback
	 */
	@Override
	public List<Query> expand(Query query) {
		Assert.notNull(query, "query cannot be null");

		logger.debug("Generating {} query variants", this.numberOfQueries);

		var response = this.chatClient.prompt()
			.user(user -> user.text(this.promptTemplate.getTemplate())
				.param("number", this.numberOfQueries)
				.param("query", query.text()))
			.call()
			.content();

		if (response == null) {
			logger.warn("Query expansion result is null. Returning the input query unchanged.");
			return List.of(query);
		}

		// "\\R" matches any line break (\n, \r\n, \r), so CRLF responses do not leave a
		// trailing carriage return on each variant. Blank lines are dropped before the
		// count check so that the check reflects the variants actually returned.
		var queries = Arrays.stream(response.split("\\R"))
			.map(String::strip)
			.filter(StringUtils::hasText)
			.map(Query::new)
			.collect(Collectors.toList());

		if (CollectionUtils.isEmpty(queries) || this.numberOfQueries != queries.size()) {
			logger.warn(
					"Query expansion result does not contain the requested {} variants. Returning the input query unchanged.",
					this.numberOfQueries);
			return List.of(query);
		}

		if (this.includeOriginal) {
			logger.debug("Including the original query in the result");
			queries.add(0, query);
		}

		return queries;
	}

	/**
	 * Creates a new builder for {@link MultiQueryExpander}.
	 * @return a fresh builder with no values set
	 */
	public static Builder builder() {
		return new Builder();
	}

	/**
	 * Fluent builder for {@link MultiQueryExpander}. Values left unset are passed to
	 * the constructor as {@code null}, which applies its defaults.
	 */
	public static final class Builder {

		private ChatClient.Builder chatClientBuilder;

		private PromptTemplate promptTemplate;

		private Boolean includeOriginal;

		private Integer numberOfQueries;

		private Builder() {
		}

		/**
		 * Sets the builder used to create the chat client (required).
		 * @param chatClientBuilder the chat client builder
		 * @return this builder
		 */
		public Builder chatClientBuilder(ChatClient.Builder chatClientBuilder) {
			this.chatClientBuilder = chatClientBuilder;
			return this;
		}

		/**
		 * Sets a custom prompt template for the expansion request.
		 * @param promptTemplate the prompt template
		 * @return this builder
		 */
		public Builder promptTemplate(PromptTemplate promptTemplate) {
			this.promptTemplate = promptTemplate;
			return this;
		}

		/**
		 * Sets whether the input query is included in the expansion result.
		 * @param includeOriginal {@code true} to prepend the input query
		 * @return this builder
		 */
		public Builder includeOriginal(Boolean includeOriginal) {
			this.includeOriginal = includeOriginal;
			return this;
		}

		/**
		 * Sets how many query variants are requested from the model.
		 * @param numberOfQueries the number of variants
		 * @return this builder
		 */
		public Builder numberOfQueries(Integer numberOfQueries) {
			this.numberOfQueries = numberOfQueries;
			return this;
		}

		/**
		 * Builds the expander from the configured values.
		 * @return a new {@link MultiQueryExpander}
		 */
		public MultiQueryExpander build() {
			return new MultiQueryExpander(this.chatClientBuilder, this.promptTemplate, this.includeOriginal,
					this.numberOfQueries);
		}

	}

}
Loading