In [None]:
import os 
from dotenv import load_dotenv
from langchain_community.document_loaders import PyPDFLoader,TextLoader
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import DashScopeEmbeddings
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.chat_models import ChatTongyi
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_core.runnables import RunnableLambda
from langchain_core.messages import HumanMessage, AIMessage, SystemMessage
from dashscope.rerank.text_rerank import TextReRank
from operator import itemgetter
from pylatexenc.latexwalker import LatexWalker, LatexEnvironmentNode
from langchain_text_splitters import RecursiveCharacterTextSplitter

读取环境中的api_key

In [None]:
# load_dotenv()
os.environ["DASHSCOPE_API_KEY"] = "your_dashcope_api_key"

基于 LaTeX 感知的智能分割

In [None]:
# from langchain_core.documents import Document

# # === LaTeX 感知的智能分割函数 ===
# def latex_aware_split_documents(documents, chunk_size=800, chunk_overlap=50, preserve_envs=None):
#     if preserve_envs is None:
#         preserve_envs = {
#             'theorem', 'lemma', 'definition', 'proof',
#             'corollary', 'example', 'proposition', 'remark'
#         }

#     all_chunks = []

#     for doc in documents:
#         text = doc.page_content
#         metadata = doc.metadata

#         # 尝试用 pylatexenc 解析 LaTeX
#         try:
#             lw = LatexWalker(text)
#             nodelist, _, _ = lw.get_latex_nodes(pos=0)
#         except Exception as e:
#             print(f"⚠️ LaTeX parsing failed for {metadata.get('source', 'unknown')}, falling back to plain split: {e}")
#             # 回退到普通文本分割
#             fallback_splitter = RecursiveCharacterTextSplitter(
#                 chunk_size=chunk_size,
#                 chunk_overlap=chunk_overlap,
#                 separators=["\n\n", "\n", " ", ""]
#             )
#             chunks = fallback_splitter.split_text(text)
#             all_chunks.extend([
#                 Document(page_content=chunk, metadata=metadata)
#                 for chunk in chunks if chunk.strip()
#             ])
#             continue

#         # 构建语义单元：普通文本累积，重要环境单独成块
#         semantic_units = []
#         current_text = ""

#         for node in nodelist:
#             if isinstance(node, LatexEnvironmentNode) and node.environmentname in preserve_envs:
#                 if current_text.strip():
#                     semantic_units.append(("text", current_text))
#                     current_text = ""
#                 env_str = node.latex_verbatim()
#                 semantic_units.append(("env", env_str))
#             else:
#                 current_text += node.latex_verbatim()

#         if current_text.strip():
#             semantic_units.append(("text", current_text))

#         # 合并为最终 chunks
#         current_chunk = ""
#         for unit_type, content in semantic_units:
#             if unit_type == "env":
#                 # 重要环境：单独成块（即使超限）
#                 if current_chunk.strip():
#                     all_chunks.append(Document(page_content=current_chunk, metadata=metadata))
#                     current_chunk = ""
#                 if content.strip():
#                     all_chunks.append(Document(page_content=content, metadata=metadata))
#             else:
#                 # 普通文本：按 chunk_size 切分
#                 text_splitter = RecursiveCharacterTextSplitter(
#                     chunk_size=chunk_size,
#                     chunk_overlap=chunk_overlap,
#                     separators=["\n\n", "\n", " ", ""]
#
#                             
#                                                               )



#                 sub_chunks = text_splitter.split_text(content)
#                 for sc in sub_chunks:
#                     if len(current_chunk) + len(sc) <= chunk_size:
#                         current_chunk += sc
#                     else:
#                         if current_chunk.strip():
#                             all_chunks.append(Document(page_content=current_chunk, metadata=metadata))
#                         current_chunk = sc
#         if current_chunk.strip():
#             all_chunks.append(Document(page_content=current_chunk, metadata=metadata))

#     # 最终过滤空 chunk
#     return [d for d in all_chunks if d.page_content.strip()]


# # === 主流程 ===
# loader_3 = TextLoader(file_path=file_path_3, encoding="utf-8")
# docs = loader_3.load()

# # 使用 LaTeX 感知分割器（关键修改！）
# splitte_docs = latex_aware_split_documents(
#     docs,
#     chunk_size=800,      # 建议 ≥800 以容纳完整定理
#     chunk_overlap=50
# )

外部知识文档地址，以及加载

In [None]:
# 只是用 loader_3 作为外部文档
file_path_3="/./thesis-some chapter.txt" 
loader_3 = TextLoader(file_path = file_path_3)
docs = loader_3.load()

In [None]:
# 基于txt 格式的文档分割
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=50,
    separators=["\n\n", "\n", " ", "", 
                                r"\end{lemma}",
                                r"\end{theorem}",
                                r"\end{definition}",
                                r"\end{proof}",
                                r"\end{corollary}",
                                r"\end{example}",
                                r"\end{remark}"
                            ]
)

splitte_docs = text_splitter.split_documents(docs)


# 初始化 embedding 模型，指定 model="text-embedding-v4"
embeddings = DashScopeEmbeddings(
    model="text-embedding-v4"
)


# 向量数据库储存地址
persist_dir = "/home/wenhao/projects/Hilbert scheme/chroma_db"

# 向量数据库， 如果已存在，不需要重复创建
if os.path.exists(persist_dir) and os.listdir(persist_dir):
    db = Chroma(persist_directory=persist_dir, embedding_function=embeddings)
else:
    db = Chroma.from_documents(
        documents=splitte_docs,
        embedding=embeddings,
        persist_directory=persist_dir
    )

retriever = db.as_retriever( search_kwargs={"k": 20} )

Chunks 的 reranking

In [None]:
# reranker
def rerank_retriever(query):
    docs = retriever.invoke(query)
    clean_docs = [d for d in docs if d.page_content.strip()]
    initial_docs = [d.page_content for d in clean_docs]

    # print(f"Number of input docs to reranker: {len(initial_docs)}")

    if not initial_docs:
        return []

    response = TextReRank.call(
        model="qwen3-rerank",
        query=query,
        documents=initial_docs,
        top_n= 5
        # instruct="Given a question, retrieve passages that best answer it."
    )
    
    text_to_doc = {d.page_content: d for d in clean_docs}
    
    reranked_docs = []
    for item in response.output.results:
        idx = item.index  # ✅ 正确方式：使用 index
        if 0 <= idx < len(initial_docs):
            text = initial_docs[idx]
            if text in text_to_doc:
                reranked_docs.append(text_to_doc[text])

    return reranked_docs

reranker 的调试

In [None]:
# query = "What is Punctual Hilbert scheme?"
# results = db.similarity_search_with_score(query, k=5)
# for doc, score in results:
#     print(f"Score: {score:.4f}")
#     print(doc.page_content)
#     print("-" * 40)

# top5_chunks = rerank_retriever(query)

# # 打印结果
# for i, doc in enumerate(top5_chunks, 1):
#     print(f"Top {i}:")
#     print(doc.page_content)
#     print("-" * 60)

Score: 0.3790
\section{Punctual Hilbert schemes}\label{Punctual Hilbert schemes}
----------------------------------------
Score: 0.6119
embedding. In general, even when the defining equations of punctual Hilbert schemes are known, it remains difficult to understand their geometry or to compute their invariants.
----------------------------------------
Score: 0.6227
The main objects of study in this article are the punctual Hilbert schemes of irreducible curve singularities defined over an algebraically closed field  $ k $  of characteristic～ $ 0 $ . Given a curve singularity  $ (C,O) $  and an integer  $ \ell \in \mathbb{N} $ , the  $ \ell $ -th punctual Hilbert scheme of  $ (C,O) $ , denoted  $ C^{[\ell]} $ , is the moduli space parametrizing  $ 0 $ -dimensional subschemes of  $ (C,O) $  of length  $ \ell $ . This is a special case of Grothendieck's
----------------------------------------
Score: 0.6240
Recently, punctual Hilbert schemes have attracted considerable attention from two 

多轮对话


In [None]:
#  定义不同轮次的特定prompt
def get_prompt_for_round(round_num):
    """根据对话轮次返回特定prompt"""
    if round_num == 1:
        return "专注于介绍相关数学的基础知识。请用简单易懂的语言回答问题。"
    elif round_num == 2:
        return "专注于深入解释相关数学的概念。请提供更详细的技术细节。"
    elif round_num >= 3:
        return "，专注于解决复杂问题和提供专业见解。请提供深度分析。"
    return "你是一个数学教授，根据上下文回答问题。"


# 4. 构建链式结构
prompt_template = ChatPromptTemplate.from_messages([
    ("system", "你是一个严谨的数学专家。\n"
    #  "用户额外提供了以下参考文档，如何文档非空白内容，请结合该文档和检索结果回答：\n{file}\n\n"
     "检索到的相关上下文如下：\n{context}\n\n"
     "如果以上内容不包含答案，请回答：'根据已有资料无法回答。'" "{system_prompt}"
    ),
    MessagesPlaceholder(variable_name="history"),
    ("human", "{question}")

])


# 初始化LLM
llm = ChatTongyi(model = "qwen-turbo",temperature = 0)


def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)


# 创建链
chain = (
{     
    "context": itemgetter("question") | RunnableLambda(rerank_retriever)| format_docs,
    "system_prompt": RunnableLambda(itemgetter("round_num")) | RunnableLambda(get_prompt_for_round),
    "history": itemgetter("history"),
    "question": itemgetter("question"),
    # "file": itemgetter("file")
}
    | prompt_template
    | llm
)

# 创建对话历史管理器
store = {}

def get_session_history(session_id: str):
    if session_id not in store:
        store[session_id] = ChatMessageHistory()
    return store[session_id]

conversational_chain = RunnableWithMessageHistory(
    chain,
    get_session_history,
    input_messages_key="question",
    history_messages_key="history"
)

#  运行多轮对话示例
session_id = "user_session_1"

用户query的重写

In [None]:
llm_rewriting_query = ChatTongyi(
    model="qwen3-1.7b",
    temperature=0,
    api_key=os.getenv("DASHSCOPE_API_KEY"),
    model_kwargs={"enable_thinking": False}
)


prompt_rewrite = ChatPromptTemplate.from_messages([
    ("system", 
     "你是一个查询改写助手。请根据对话历史，将用户最新的问题改写成一个独立、完整、清晰的问题。\n"
     "不要回答问题，不要添加解释，只输出改写后的问题。"
    ),
    ("human", "对话历史：{history}\n用户最新问题：{question}")
])


def rewriting_query(question: str, history: list ) -> str:
    chain = prompt_rewrite | llm_rewriting_query
    history_str = "\n".join([f"{'用户' if isinstance(msg, HumanMessage) else 'AI'}: {msg.content}" for msg in history])
    response = chain.invoke({"question": question,
                            "history": history_str})
    return response.content.strip()  


In [4]:
query = "Hilbert scheme 是什么？"
rewritten_query = rewriting_query(query, get_session_history(session_id).messages)

# loader_2 = TextLoader(file_path = file_path_2)
# full_docs = loader_2.load()                         # 返回 Document 对象列表

# # 合并所有页面文本
# pdf_text = "\n".join([page.page_content for page in full_docs])

# 第一轮对话
response = conversational_chain.invoke(
    {     
        #  "file": pdf_text, 
        "question": rewritten_query,
        "round_num": 1,
           
    },
    config={"configurable": {"session_id": session_id}}
)
print(f"轮次1: {response}\n")

轮次1: content='Hilbert scheme 是数学中一个重要的概念，主要用于研究几何对象（如曲线、曲面等）的子集的结构。\n\n简单来说，**Hilbert scheme** 是一个“空间”，它把具有某种特定性质的子集（比如长度为 $\\ell$ 的 0-维子集）全部“放在一起”，并按照某种方式分类。这个“空间”中的每一个点，对应一个特定的子集。\n\n举个例子：  \n假设你有一个曲线 $C$，在它的某一点 $O$ 附近，你想研究所有“长度为 $\\ell$”的子集。这些子集可以看作是某些代数方程的解的集合。Hilbert scheme 就是把这些子集都收集起来，并给出它们的结构和性质。\n\n更具体地说，对于一个曲线奇点 $(C, O)$，它的 **$\\ell$-th punctual Hilbert scheme**（记作 $C^{[\\ell]}$）就是所有长度为 $\\ell$ 的 0-维子集的集合，这些子集都位于点 $O$ 附近。\n\nHilbert scheme 在代数几何中有广泛应用，特别是在研究几何对象的变形、对称性以及不变量等方面。' additional_kwargs={} response_metadata={'model_name': 'qwen-turbo', 'finish_reason': 'stop', 'request_id': 'bfbf9490-d852-437b-b2a8-483924b9044d', 'token_usage': {'input_tokens': 543, 'output_tokens': 262, 'prompt_tokens_details': {'cached_tokens': 0}, 'total_tokens': 805}} id='lc_run--019bc97f-7b10-77c2-adfd-9c07e156f2b7-0' tool_calls=[] invalid_tool_calls=[]

