# Multi-Query
利用大模型将原本的单一查询改写为多个子查询，然后对每个子查询分别进行 RAG 检索，最后将检索结果汇总、去重并排序。


In [None]:
# ！表示在jupter中运行shell命令
# % 更安全，自动安装到当前的内核环境中
%pip install langchain langchain_ollama
%pip install chromadb langchain_chormadb

In [None]:
from langchain_ollama import ChatOllama, OllamaEmbeddings

# Embedding model used to vectorize both the documents and the queries.
embeddings = OllamaEmbeddings(model="bge-m3")

# Chat model used for query rewriting and for final answer generation.
llm = ChatOllama(model="huihui_ai/deepseek-r1-abliterated:7b")

# Smoke test 1: a single blocking call.
result = llm.invoke("你好？")
print(result.content)

# Smoke test 2: streaming, printing each piece as soon as it arrives.
for piece in llm.stream("你是谁？"):
    print(piece.content, end="", flush=True)

## 向量化存储

In [None]:
# Vectorize the local file "chineseJH.txt" with Chroma, reusing the on-disk
# store on later runs instead of re-embedding the text.
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import Chroma
from langchain_text_splitters import RecursiveCharacterTextSplitter

import os


def load_or_build_db(txt_path, persist_dir, chunk_size=500, chunk_overlap=50):
    """Open the Chroma store in *persist_dir*, building it from *txt_path* if needed.

    Args:
        txt_path: UTF-8 text file to index.
        persist_dir: Directory that holds the persisted Chroma store.
        chunk_size: Maximum characters per chunk passed to the splitter.
        chunk_overlap: Characters shared between neighbouring chunks.

    Returns:
        A ``Chroma`` vector store that uses the module-level ``embeddings``.

    The store is built only when *persist_dir* does not exist or is empty;
    otherwise the existing store is reopened without re-embedding.
    """
    if os.path.exists(persist_dir) and os.listdir(persist_dir):
        store = Chroma(persist_directory=persist_dir, embedding_function=embeddings)
        print(f"{persist_dir} 已存在，无需重复向量化。")
        return store

    documents = TextLoader(txt_path, encoding="utf-8").load()
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size, chunk_overlap=chunk_overlap
    )
    chunks = splitter.split_documents(documents)

    # With chromadb >= 0.4, data written to persist_directory is saved
    # automatically; the explicit db.persist() call is deprecated.
    store = Chroma.from_documents(chunks, embeddings, persist_directory=persist_dir)
    print(f"已成功将{os.path.basename(txt_path)}向量化并存储到{persist_dir}目录。")
    return store


db = load_or_build_db("./chineseJH.txt", "chineseJH_chroma_db")

# Alternative corpus (disabled):
# db = load_or_build_db("./1-55.txt", "55_db")

## 利用大模型生成同义问题

In [None]:
import re

from langchain_core.output_parsers import BaseOutputParser, StrOutputParser
from langchain_core.prompts import PromptTemplate

template = """
请为下面的问题生成3个不同表达方式的同义句,并且每个问题占一行，问题之间没有空行，严格按照上述格式：
问题：{query}
"""
prompt = PromptTemplate(template=template, input_variables=["query"])

# Reasoning models such as deepseek-r1 may prefix their answer with a
# <think>...</think> block; that text must not be treated as rewritten queries.
_THINK_BLOCK = re.compile(r"<think>.*?</think>", re.DOTALL)


def parse_query_lines(text):
    """Turn the model's raw answer into a list of queries, one per non-blank line.

    Args:
        text: Raw string output of the LLM.

    Returns:
        The stripped, non-empty lines of *text* after removing any
        ``<think>...</think>`` block. Blank lines are dropped so that no empty
        query is sent to the retriever.
    """
    answer = _THINK_BLOCK.sub("", text)
    return [line.strip() for line in answer.splitlines() if line.strip()]


user_query = "令狐冲最后怎么样了？"
# Prompt -> LLM -> plain text -> list of rewritten queries.
get_multi_query_chain = prompt | llm | StrOutputParser() | parse_query_lines

multi_queries = get_multi_query_chain.invoke({"query": user_query})
print(multi_queries)

## 多查询检索合并

In [None]:
from langchain.load import dumps

print("="*20 + "测试查询" + "="*20)

# Run a top-5 similarity search for each rewritten query and concatenate all
# hits into one flat list (duplicates across queries are kept at this stage).
retrieves = [
    hit
    for sub_query in multi_queries
    for hit in db.similarity_search(sub_query, k=5)
]
print(retrieves)


print("="*50)

In [None]:
# Build the chain's retrieval step: a retriever view over the Chroma store `db`
# created above, using the library's default search settings.
retriver = db.as_retriever()

# Flatten the per-query document lists [[...], [...], ...] into one list without duplicates.
def get_unique_doc(doc_lists):
    """Merge several lists of documents into a single de-duplicated list.

    Two documents are duplicates when their ``page_content`` strings are equal.
    The first occurrence is kept, and the result follows the order in which
    each distinct content was first seen.
    """
    by_content = {}
    for group in doc_lists:
        for doc in group:
            # setdefault keeps the first document seen for each content string;
            # dicts preserve insertion order, so retrieval order is retained.
            by_content.setdefault(doc.page_content, doc)
    return list(by_content.values())


# Multi-query retrieval pipeline: rewrite -> retrieve per query -> flatten & de-duplicate.
retriver_chain = (
    get_multi_query_chain
    | retriver.map()  # run the retriever once for each rewritten query
    | get_unique_doc
)

sample_input = {"query": "令狐冲会的武功有哪些？"}
docs = retriver_chain.invoke(sample_input)
print(docs)
print(len(docs))

## 整合生成

In [None]:
from operator import itemgetter

# Callable that pulls the "question" field out of the chain's input dict.
question_extractor = itemgetter("question")


def format_docs(docs):
    """Join retrieved documents into one context string, one chunk per paragraph."""
    return "\n\n".join(doc.page_content for doc in docs)


template = """请根据以下【上下文】内容，认真回答【问题】。  
【上下文】
{context}

【问题】
{question}

【请在下方作答】
"""
prompt = PromptTemplate.from_template(template)

final_rag_chain = (
    {
        # Rewrite the question into several queries, retrieve for each,
        # de-duplicate, then flatten the documents into plain text.
        # StrOutputParser cannot be used here: it expects a str or message,
        # not a list of Document objects.
        "context": {"query": question_extractor} | retriver_chain | format_docs,
        "question": question_extractor,
    }
    | prompt
    | llm
    | StrOutputParser()
)

print("="*50 + "RAG效果")

# user_input = input("输入问题:")
# print(user_input)

# Leave empty to use the default demo question below.
user_input = ""

question = "辟邪剑法都被谁拿到过？" if user_input == "" else user_input
for chunk in final_rag_chain.stream({"question": question}):
    print(chunk, end="", flush=True)

print("\n" + "="*50 + "无RAG效果")
# Baseline: the same question sent straight to the LLM without retrieved context.
simple_chain = llm | StrOutputParser()
for chunk in simple_chain.stream(question):
    print(chunk, end="", flush=True)
print()