In [150]:
from sparkai.llm.llm import ChatSparkLLM, ChunkPrintHandler
from sparkai.core.messages import ChatMessage
from langchain.schema import Document
from langchain.schema.retriever import BaseRetriever
from typing import List
from pydantic import Field
from langchain.document_loaders import DirectoryLoader
import os

# 设置 tesseract 的路径
os.environ["TESSERACT_PATH"] = r"C:\Program Files\Tesseract-OCR\tesseract.exe"
os.environ["PATH"] += os.pathsep + r"C:\Program Files\Tesseract-OCR"

SPARKAI_URL = 'wss://spark-api.xf-yun.com/v4.0/chat'
SPARKAI_APP_ID = '2087ff0e'
SPARKAI_API_SECRET = 'ZDg4Mjg2NDlhNWUwNzA5ZjM5M2YxOTI5'
SPARKAI_API_KEY = '562ea31be2df40e0b808ad7d03145cfe'
SPARKAI_DOMAIN = '4.0Ultra'

In [151]:
class SparkAPIRetriever(BaseRetriever):
    document_dir: str = Field(..., description="The directory where documents are stored")
    documents: List[Document] = []
    loader: DirectoryLoader = None

    def __init__(self, document_dir: str, **kwargs):
        super().__init__(**kwargs)
        self.document_dir = document_dir
        self.loader = DirectoryLoader(path=self.document_dir, show_progress=True)
        self.documents = self.load_documents()

    def load_documents(self) -> List[Document]:
        # 使用 DirectoryLoader 加载文档
        return self.loader.load()

    def _get_relevant_documents(self, query: str, *, run_manager=None) -> List[Document]:
        return [doc for doc in self.documents if query.lower() in doc.page_content.lower()]

    async def _aget_relevant_documents(self, query: str, *, run_manager=None) -> List[Document]:
        return await super()._aget_relevant_documents(query, run_manager=run_manager)

In [152]:
def test_spark_api_retriever():
    document_dir = '../uploads'  # 确保这是正确的路径
    retriever = SparkAPIRetriever(document_dir=document_dir)

    documents = retriever.load_documents()
    print(f"Loaded {len(documents)} documents")

    query = "二叉树"
    relevant_documents = retriever._get_relevant_documents(query=query)
    print(f"Found {len(relevant_documents)} relevant documents for query '{query}'")

    for doc in relevant_documents:
        # print(doc)
        print(f"Relevant Document: {doc.metadata['source']}")

# 运行测试函数
test_spark_api_retriever()


The MIME type of '..\\uploads\\二叉树.md' is "cannot open `..\\uploads\\\\344\\272\\214\\345\\217\\211\\346\\240\\221.md' (Illegal byte sequence)". This file type is not currently supported in unstructured.
The MIME type of '..\\uploads\\二叉树2.pdf' is "cannot open `..\\uploads\\\\344\\272\\214\\345\\217\\211\\346\\240\\2212.pdf' (Illegal byte sequence)". This file type is not currently supported in unstructured.
100%|██████████| 4/4 [00:00<00:00,  4.28it/s]
The MIME type of '..\\uploads\\二叉树.md' is "cannot open `..\\uploads\\\\344\\272\\214\\345\\217\\211\\346\\240\\221.md' (Illegal byte sequence)". This file type is not currently supported in unstructured.
The MIME type of '..\\uploads\\二叉树2.pdf' is "cannot open `..\\uploads\\\\344\\272\\214\\345\\217\\211\\346\\240\\2212.pdf' (Illegal byte sequence)". This file type is not currently supported in unstructured.
100%|██████████| 4/4 [00:00<00:00,  4.21it/s]

Loaded 4 documents
Found 3 relevant documents for query '二叉树'
Relevant Document: ..\uploads\binary_Tree.txt
Relevant Document: ..\uploads\二叉树.md
Relevant Document: ..\uploads\二叉树2.pdf





In [153]:
class SparkAPILLM:
    def __init__(self, api_url: str, app_id: str, api_secret: str, api_key: str, domain: str):
        self.api_url = api_url
        self.app_id = app_id
        self.api_secret = api_secret
        self.api_key = api_key
        self.domain = domain

    def generate(self, prompt: str):
        spark = ChatSparkLLM(
            spark_api_url=self.api_url,
            spark_app_id=self.app_id,
            spark_api_key=self.api_key,
            spark_api_secret=self.api_secret,
            spark_llm_domain=self.domain,
            streaming=False,
        )
        messages = [ChatMessage(role="user", content=prompt)]
        handler = ChunkPrintHandler()
        response = spark.generate([messages], callbacks=[handler])

        # 确保正确访问 LLMResult 的内容
        generations = response.generations
        if generations and generations[0]:
            text = generations[0][0].text
            # print(f"Response from LLM: {text}")
            return {"content": text}
        return {"content": "No response from LLM"}


In [154]:
class RetrievalQA:
    def __init__(self, retriever, llm):
        self.retriever = retriever
        self.llm = llm

    def __call__(self, inputs):
        query = inputs["query"]
        docs = self.retriever._get_relevant_documents(query)
        context = " ".join([doc.page_content for doc in docs])
        response = self.llm.generate(context + "\n\n" + query)
        return response["content"]

In [155]:
def create_rag_chain():
    retriever = SparkAPIRetriever(
        document_dir='../uploads'
    )
    llm = SparkAPILLM(
        api_url=SPARKAI_URL,
        app_id=SPARKAI_APP_ID,
        api_secret=SPARKAI_API_SECRET,
        api_key=SPARKAI_API_KEY,
        domain=SPARKAI_DOMAIN
    )
    qa_chain = RetrievalQA(retriever=retriever, llm=llm)
    return qa_chain

In [156]:
# 测试 RAG 链
rag_chain = create_rag_chain()
inputs = {"query": "二叉树每个节点最多有几个孩子"}
response = rag_chain(inputs)
print(response)


The MIME type of '..\\uploads\\二叉树.md' is "cannot open `..\\uploads\\\\344\\272\\214\\345\\217\\211\\346\\240\\221.md' (Illegal byte sequence)". This file type is not currently supported in unstructured.
The MIME type of '..\\uploads\\二叉树2.pdf' is "cannot open `..\\uploads\\\\344\\272\\214\\345\\217\\211\\346\\240\\2212.pdf' (Illegal byte sequence)". This file type is not currently supported in unstructured.
100%|██████████| 4/4 [00:00<00:00,  4.11it/s]


**二叉树的每个节点最多可以有两个子节点**。在数据结构中，二叉树是一种常见的树形结构，它的每个节点包含一个数据元素以及两个指向其子节点的链接，分别称为“左子节点”和“右子节点”。

下面将详细探讨关于二叉树的性质、类型及其它相关概念：

1. **基本性质**
   - **度的限制**：二叉树的每个节点最多只能有两个子节点。
   - **层次节点数**：第i层上的最大节点数目为2^(i-1)^，其中i是从1开始的层数。
   - **深度与节点数的关系**：深度为k的二叉树至多有2^k^ - 1个节点。
   - **终端节点与分支节点关系**：如果n0表示终端节点（叶子节点）的数量，n2表示度数为2的节点数量，则n0 = n2 + 1。

2. **特殊类型**
   - **满二叉树**：每一层都被完全填满的二叉树，其节点数目符合2^h^ - 1的规律，其中h为树的高度。
   - **完全二叉树**：除了最后一层可能没有被完全填满，其余每一层都被完全填满的二叉树，且最后一层的节点都靠左排列。

为了进一步理解二叉树的遍历方式，可以考虑以下几点：

- **前序遍历**：先访问根节点，再递归访问左子树，最后递归访问右子树。
- **中序遍历**：先递归访问左子树，再访问根节点，最后递归访问右子树。
- **后序遍历**：先递归访问左子树，再递归访问右子树，最后访问根节点。

总的来说，二叉树作为一种基础而重要的数据结构，在计算机科学领域中有着广泛的应用。它的特性包括每个节点最多有两个子节点，并且存在多种特殊的二叉树结构以适应不同的应用场景。二叉树的遍历方法是理解和操作这种数据结构的关键，其中包括了递归和迭代两种主要的实现方式。
