In [1]:
import os
from langchain.document_loaders import PyPDFLoader
from typing import List, Tuple, Dict, Any
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.vectorstores import FAISS
from langchain.tools.retriever import create_retriever_tool
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain_community.embeddings import DashScopeEmbeddings
from langchain_community.chat_models import ChatOllama
from langchain.chains import RetrievalQA
from datasets import load_dataset
from langchain.schema import Document
import numpy as np
import json
import faiss

# DASHSCOPE_API_KEY = os.getenv("DASHSCOPE_API_KEY")

In [2]:
from llama_cpp import Llama
from langchain.embeddings.base import Embeddings

# Custom LangChain `Embeddings` wrapper around a local llama.cpp (GGUF) embedding model.
class LlamaCppEmbeddings(Embeddings):
    """LangChain ``Embeddings`` adapter backed by ``llama_cpp.Llama`` in embedding mode.

    Every text is embedded with its own ``Llama.embed`` call; documents and queries
    are embedded the same way (no query/document prompt distinction is applied).
    """

    def __init__(self, model_path: str, **llama_kwargs):
        """Load the GGUF model at ``model_path`` with ``embedding=True``.

        Args:
            model_path: Path to the GGUF embedding model.
            **llama_kwargs: Extra keyword arguments forwarded unchanged to
                ``llama_cpp.Llama`` (e.g. ``verbose=False`` to silence the per-call
                performance logs, or ``n_ctx=...``). Passing ``embedding`` here raises
                ``TypeError`` because it is always set to True.
        """
        self.llm = Llama(model_path=model_path, embedding=True, **llama_kwargs)

    @staticmethod
    def _as_vector(result):
        """Reduce one ``Llama.embed`` result to a single flat vector.

        If ``result`` is a non-empty list whose first element is itself a list, the
        first row is returned. Otherwise ``result`` is returned as-is; this includes
        an empty list, which the previous ``result[0]`` check crashed on.
        """
        # NOTE(review): a nested result presumably means per-token embeddings (model without
        # pooling); taking row 0 keeps the original behavior — confirm this is the intended token.
        if isinstance(result, list) and result and isinstance(result[0], list):
            return result[0]
        return result

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed each text in ``texts``; returns one vector per input, in order."""
        return [self.embed_query(text) for text in texts]

    def embed_query(self, text: str) -> List[float]:
        """Embed a single text and return its vector."""
        return self._as_vector(self.llm.embed(text))

In [14]:
class Client:
    """Lightweight RAG client: loads source documents, builds or loads a FAISS vector
    store, and retrieves the chunks most similar to a query.
    """

    def __init__(self, model_path: str = "./models/Qwen3-Embedding/Qwen3-Embedding-0.6B-Q8_0.gguf",
                 vectorstore_path: str = "faiss_db"):
        """Create the client and load the local llama.cpp embedding model.

        Args:
            model_path: Path of the GGUF embedding model passed to ``LlamaCppEmbeddings``.
            vectorstore_path: Directory the FAISS store is saved to / loaded from.
        """
        # NOTE(review): presumably a workaround for the "duplicate OpenMP runtime" abort when
        # several native libraries each load their own OpenMP; setdefault keeps a user-set value.
        os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "TRUE")
        self.vectorstore_path = vectorstore_path
        # Alternative: hosted embeddings via DashScope (requires a dashscope_api_key argument):
        # self.embeddings = DashScopeEmbeddings(
        #     model="text-embedding-v1",
        #     dashscope_api_key=dashscope_api_key
        # )
        self.embeddings = LlamaCppEmbeddings(model_path=model_path)
        self.db = None         # FAISS store; set by build_vectorstore() / load_vectorstore()
        self.retriever = None  # LangChain retriever over self.db; refreshed whenever self.db is built/loaded

    def _chunk_text(self, text: str, chunk_size=800, overlap=200) -> List[str]:
        """Split ``text`` into overlapping chunks with ``RecursiveCharacterTextSplitter``.

        ``chunk_size`` and ``overlap`` are measured in characters (``length_function=len``).
        """
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=overlap,
            length_function=len,
        )
        return splitter.split_text(text)

    def _read_pdfs(self, pdf_paths: List[str]) -> "List[Document]":
        """Load every page of each PDF in ``pdf_paths`` as a ``Document``.

        Uses ``load()`` rather than ``load_and_split()``: ``build_vectorstore`` already
        chunks each document, so pre-splitting here added a redundant second split whose
        overlap regions were embedded more than once.
        """
        docs = []
        for path in pdf_paths:
            docs.extend(PyPDFLoader(path).load())
        return docs

    def _load_json_folder(self, folder_path: str) -> "List[Document]":
        """Load every ``*.json`` file directly inside ``folder_path`` as one ``Document``.

        Each file is expected to hold a JSON object with optional ``title`` and ``content``
        keys; the document text is the title and content joined by a newline and stripped,
        with missing or null fields treated as empty. Files are visited in sorted name order
        so the index is built deterministically. Files whose top-level value is not an
        object (reported with a message) or whose text is empty are skipped.
        Metadata is ``{'source': <file path>}``.
        """
        docs = []
        for filename in sorted(os.listdir(folder_path)):
            if not filename.endswith('.json'):
                continue
            filepath = os.path.join(folder_path, filename)
            with open(filepath, encoding='utf-8') as f:
                data = json.load(f)
            if not isinstance(data, dict):
                print(f"Skipping {filepath}: expected a JSON object, got {type(data).__name__}.")
                continue
            title = data.get('title') or ''
            body = data.get('content') or ''
            content = f"{title}\n{body}".strip()
            if content:
                docs.append(Document(page_content=content, metadata={'source': filepath}))
        return docs

    def _streaming_load_dataset(self, sample_size=100, language='en', date_version='20231101') -> "List[Document]":
        """Stream Wikipedia articles from the Hugging Face ``wikimedia/wikipedia`` dataset.

        Only the first ``sample_size`` items of the ``train`` split (config
        ``"{date_version}.{language}"``) are read; streaming mode avoids downloading the
        whole dump. Items with empty text are skipped but still count toward ``sample_size``.
        Each document is the title and text joined by a newline, with metadata
        ``{'source': 'wikipedia://<language>/<id>'}``.
        """
        dataset = load_dataset("wikimedia/wikipedia", f'{date_version}.{language}', streaming=True)
        docs = []
        for i, item in enumerate(dataset['train']):
            if i >= sample_size:
                break
            text = item.get('text', '')
            title = item.get('title', '')
            if not text:
                continue
            # To cap very long articles, truncate here (e.g. text[:5000]) before indexing.
            meta = {'source': f'wikipedia://{language}/{item.get("id")}'}
            docs.append(Document(page_content=f"{title}\n{text}", metadata=meta))
        print(f"Streamed {len(docs)} Wikipedia docs.")
        return docs

    def _add_texts(self, texts: List[str], metadatas: List[dict]) -> None:
        """Embed ``texts`` and add them to the store, creating the FAISS store on first use."""
        if self.db is None:
            self.db = FAISS.from_texts(texts, embedding=self.embeddings, metadatas=metadatas)
        else:
            self.db.add_texts(texts, metadatas=metadatas)

    def build_vectorstore(self, sample_size=100, batch_size=10,
                          streaming=False, folder_path=None, pdf_paths: List[str] = None):
        """Load documents from ONE source, chunk and embed them, and save a FAISS store.

        Source precedence: ``streaming`` (Wikipedia, ``sample_size`` items) >
        ``pdf_paths`` > ``folder_path`` (JSON files). Lower-priority sources are ignored.

        Each document's chunks are added in batches of at most ``batch_size`` (values
        below 1 are treated as 1); a batch never spans two documents, which bounds how
        much is embedded per call. If ``self.db`` already holds a store, the new chunks
        are appended to it. When any data was added, the store is saved to
        ``self.vectorstore_path`` and ``self.retriever`` is refreshed.
        """
        if streaming:
            docs = self._streaming_load_dataset(sample_size)
        elif pdf_paths is not None:
            docs = self._read_pdfs(pdf_paths)
        elif folder_path is not None:
            docs = self._load_json_folder(folder_path)
        else:
            docs = []

        step = max(1, batch_size)
        for i, doc in enumerate(docs):
            chunks = self._chunk_text(doc.page_content)
            for start in range(0, len(chunks), step):
                batch = chunks[start:start + step]
                # A separate metadata dict per chunk, so the chunks do not alias one shared object.
                self._add_texts(batch, [dict(doc.metadata) for _ in batch])
            print(f"Processed {i+1}/{len(docs)} articles...")

        if self.db is not None:
            self.db.save_local(self.vectorstore_path)
            self.retriever = self.db.as_retriever()
            print(f"Vectorstore saved to {self.vectorstore_path}")
        else:
            print("No data processed.")

    def load_vectorstore(self) -> None:
        """Load the saved FAISS store from ``self.vectorstore_path`` and create ``self.retriever``.

        Raises:
            FileNotFoundError: if ``self.vectorstore_path`` does not exist.
        """
        if not os.path.exists(self.vectorstore_path):
            raise FileNotFoundError(f"Vectorstore directory '{self.vectorstore_path}' not found.")
        # The saved docstore is pickled, hence allow_dangerous_deserialization:
        # only load stores this project wrote itself.
        self.db = FAISS.load_local(
            self.vectorstore_path,
            embeddings=self.embeddings,
            allow_dangerous_deserialization=True
        )
        self.retriever = self.db.as_retriever()
        print(f"Vectorstore {self.vectorstore_path} loaded.")

    @staticmethod
    def _cosine_scores(query_vec: np.ndarray, doc_vecs: np.ndarray) -> List[float]:
        """Cosine similarity between 1-D ``query_vec`` and each row of 2-D ``doc_vecs``.

        Computed in float64; a zero-norm vector yields 0.0 instead of NaN.
        """
        q = np.asarray(query_vec, dtype=np.float64)
        d = np.asarray(doc_vecs, dtype=np.float64)
        dots = d @ q
        denom = np.linalg.norm(d, axis=1) * np.linalg.norm(q)
        sims = np.divide(dots, denom, out=np.zeros_like(dots), where=denom > 0)
        return [float(s) for s in sims]

    def retrieve(self, query: str, top_k=4, use_mmr=False, fetch_k=None):
        """Retrieve the ``top_k`` stored chunks most relevant to ``query``.

        Args:
            query: Query text.
            top_k: Number of documents to return.
            use_mmr: If True, use maximal-marginal-relevance search (more diverse
                results) instead of plain similarity search.
            fetch_k: MMR candidate pool size; defaults to ``top_k * 2``. Ignored when
                ``use_mmr`` is False.

        Returns:
            ``(docs, doc_vecs, scores, query_vec)``, where:
            - ``docs`` are the Documents in search order;
            - ``doc_vecs`` are their embeddings, re-computed with ``self.embeddings``, as
              lists of float32-rounded floats;
            - ``scores`` are the cosine similarities of each document vector to the query
              vector, in both modes (0.0 when a vector has zero norm);
            - ``query_vec`` is the query embedding as a list.

            ``docs``, ``doc_vecs`` and ``scores`` are empty when nothing is found.

        Raises:
            ValueError: if no vector store has been built or loaded.
        """
        if self.db is None:
            raise ValueError("Vectorstore尚未加载，请先调用load_vectorstore或build_vectorstore")

        query_arr = np.asarray(self.embeddings.embed_query(query), dtype=np.float32)
        query_vec = query_arr.tolist()

        if use_mmr:
            if fetch_k is None:
                fetch_k = top_k * 2
            docs = list(self.db.max_marginal_relevance_search(query, k=top_k, fetch_k=fetch_k))
        else:
            # With FAISS's default (Euclidean) strategy this score is a distance, not a
            # similarity, so it is dropped in favour of the cosine score computed below.
            docs = [doc for doc, _ in self.db.similarity_search_with_score(query, k=top_k)]

        if not docs:
            return [], [], [], query_vec

        doc_arr = np.asarray(
            self.embeddings.embed_documents([doc.page_content for doc in docs]),
            dtype=np.float32,
        )
        scores = self._cosine_scores(query_arr, doc_arr)
        return docs, doc_arr.tolist(), scores, query_vec

    # Answer generation (currently disabled). Two variants were sketched below: a RetrievalQA
    # chain over an Ollama-served LLM, and a tool-calling agent that uses the retriever as a tool.
    # NOTE(review): `self.ollama_model` is never set in __init__; define it before re-enabling.
    # def query(self, question: str) -> str:
    #     """
    #     Answer `question` from the loaded vector store and return the generated answer.
    #     """
    #     if self.retriever is None:
    #         raise RuntimeError("Retriever not initialized. Call 'load_vectorstore()' first.")
    #     llm = ChatOllama(model=self.ollama_model)
    #     qa_chain = RetrievalQA.from_chain_type(
    #         llm=llm,
    #         retriever=self.retriever,
    #         chain_type="stuff",
    #         return_source_documents=False
    #     )
    #     result = qa_chain.invoke({"query": question})
    #     return result["result"]
        # Create the retrieval tool
        # retrieval_tool = create_retriever_tool(
        #     self.retriever,
        #     name="pdf_extractor",
        #     description="Tool to answer queries based on the processed PDF content."
        # )

        # prompt = ChatPromptTemplate.from_messages([
        #     ("system", 
        #      """
        #      你是AI助手，请根据提供的上下文回答问题，确保提供所有细节，
        #      如果答案不在上下文中，请说 '答案不在上下文中'，不要提供错误的答案
        #      """),
        #     ("human", "{input}"),
        #     ("placeholder", "{agent_scratchpad}")
        # ])
        # agent = create_tool_calling_agent(llm, [retrieval_tool], prompt)
        # executor = AgentExecutor(agent=agent, tools=[retrieval_tool], verbose=False)
        # result = executor.invoke({"input": question})
        # return result.get('output', '')


In [15]:
# Load the saved common-sense index; on a fresh checkout (no saved index yet) build it
# from the JSON corpus instead, so Restart & Run All works without manual edits.
client = Client(vectorstore_path="./common_sense_db")
if os.path.exists(client.vectorstore_path):
    client.load_vectorstore()
else:
    client.build_vectorstore(batch_size=10, streaming=False, folder_path="./classified/common_sense")

llama_model_loader: loaded meta data with 36 key-value pairs and 310 tensors from ./models/Qwen3-Embedding/Qwen3-Embedding-0.6B-Q8_0.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = qwen3
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Qwen3 Embedding 0.6b
llama_model_loader: - kv   3:                           general.basename str              = qwen3-embedding
llama_model_loader: - kv   4:                         general.size_label str              = 0.6B
llama_model_loader: - kv   5:                            general.license str              = apache-2.0
llama_model_loader: - kv   6:                   general.base_model.count u32              = 1
llama_model_loader: - kv  

Vectorstore ./common_sense_db loaded.


In [18]:
docs, doc_vecs, scores, q_vec = client.retrieve("What is Chinese", top_k=5, use_mmr=True)
# Display (score, source) for each hit as the cell's output; scores come from Client.retrieve.
[(round(score, 3), doc.metadata.get("source")) for doc, score in zip(docs, scores)]

llama_perf_context_print:        load time =     165.61 ms
llama_perf_context_print: prompt eval time =      61.21 ms /     4 tokens (   15.30 ms per token,    65.35 tokens per second)
llama_perf_context_print:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_perf_context_print:       total time =      65.17 ms /     5 tokens
llama_perf_context_print:        load time =     165.61 ms
llama_perf_context_print: prompt eval time =      85.99 ms /     4 tokens (   21.50 ms per token,    46.52 tokens per second)
llama_perf_context_print:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_perf_context_print:       total time =      91.01 ms /     5 tokens
llama_perf_context_print:        load time =     165.61 ms
llama_perf_context_print: prompt eval time =     828.39 ms /   147 tokens (    5.64 ms per token,   177.45 tokens per second)
llama_perf_context_print:        eval time = 