In [27]:
import os
from langchain.document_loaders import PyPDFLoader
from typing import List, Tuple, Dict, Any
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.vectorstores import FAISS
from langchain.tools.retriever import create_retriever_tool
from langchain.agents import AgentExecutor, create_tool_calling_agent
from langchain_community.embeddings import DashScopeEmbeddings
from langchain.chains import RetrievalQA
from datasets import load_dataset
from langchain.schema import Document
import numpy as np
import json
import faiss
from langchain.chat_models import ChatOllama

# DASHSCOPE_API_KEY = os.getenv("DASHSCOPE_API_KEY")

In [28]:
from llama_cpp import Llama
from langchain.embeddings.base import Embeddings

# 自定义 LangChain 的 Embeddings 类封装
class LlamaCppEmbeddings(Embeddings):
    def __init__(self, model_path: str):
        self.llm = Llama(model_path=model_path, embedding=True)

    def embed_documents(self, texts: list[str]):
        # return [self.llm.embed(text)["data"][0]["embedding"] for text in texts]
        embeddings = []
        for text in texts:
            result = self.llm.embed(text)
            if isinstance(result, list) and isinstance(result[0], list):
                embeddings.append(result[0])
            else:
                embeddings.append(result)
        return embeddings

    def embed_query(self, text):
        # return self.llm.embed(text)["data"][0]["embedding"]
        result = self.llm.embed(text)
        return result[0] if isinstance(result, list) and isinstance(result[0], list) else result

In [29]:
class Proof():
    def __init__(self, document: Document, vector: List[np.ndarray], score: float):
        self.document = document
        self.vector = vector
        self.score = score
        

In [None]:
class Client:
    """
    轻量级rag客户端，负责数据集加载、向量存储构建与检索。
    """
    def __init__(self, model_path: str = "./models/Qwen3-Embedding/Qwen3-Embedding-0.6B-Q8_0.gguf", 
                vectorstore_path: str = "faiss_db"): # dashscope_api_key: str,使用api调用embedding模型
        os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "TRUE")
        self.vectorstore_path = vectorstore_path
        # self.embeddings = DashScopeEmbeddings(
        #     model="text-embedding-v1",
        #     dashscope_api_key=dashscope_api_key
        # )
        self.embeddings = LlamaCppEmbeddings(model_path=model_path)
        self.db: FAISS = None
        self.retriever = None
        self.load_vectorstore()

    def _chunk_text(self, text: str, chunk_size=800, overlap= 200) -> list[str]:
        """
        将文本分块处理，使用递归字符分割器。
        """
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,
            chunk_overlap=overlap,
            length_function=len
        )
        return splitter.split_text(text)

    # 读取PDF文件并提取文本内容
    def _read_pdfs(self, pdf_paths: List[str]) -> List[Document]:
        docs = []
        for path in pdf_paths:
            loader = PyPDFLoader(path)
            pages = loader.load_and_split()
            docs.extend(pages)
        return docs

    # 读取JSON文件夹中的所有文件
    def _load_json_folder(self, folder_path: str) -> List[Document]:
        docs = []
        for filename in os.listdir(folder_path):
            if not filename.endswith('.json'):
                continue
            filepath = os.path.join(folder_path, filename)
            with open(filepath, encoding='utf-8') as f:
                data = json.load(f)
            content = f"{data.get('title', '')}\n{data.get('content', '')}".strip()
            if content:
                docs.append(Document(page_content=content, metadata={'source': filepath}))
        return docs
    
    # 在线读取数据集
    def _streaming_load_dataset(self, sample_size=100, language='en', date_version='20231101') -> List[str]:
        # 启用streaming模式在线读取huggingface datasets
        dataset = load_dataset("wikimedia/wikipedia", f'{date_version}.{language}', streaming=True)
        docs = []
        for i, item in enumerate(dataset['train']):
            if i >= sample_size:
                break
            text = item.get('text', '')
            title = item.get('title', '')
            if not text:
                continue
            # # 抽取前 5000 字，避免过长
            # snippet = text[:5000]
            meta = {'source': f'wikipedia://{language}/{item.get("id")}'}
            docs.append(Document(page_content=f"{title}\n{text}", metadata=meta))
        print(f"Streamed {len(docs)} Wikipedia docs.")
        return docs
    
    def build_vectorstore(self, sample_size=100, batch_size=10, 
                          streaming=False, folder_path=None, pdf_paths:List[str]=None):
        docs = []
        if streaming:
            # 在线读取数据集
            docs.extend(self._streaming_load_dataset(sample_size))
        elif folder_path is not None and pdf_paths is None:
            # 从指定文件夹加载JSON文件
            docs.extend(self._load_json_folder(folder_path))
        elif pdf_paths is not None:
            # 从PDF文件加载
            docs.extend(self._read_pdfs(pdf_paths))

        texts, metadatas = [], []
        faiss_id = 0
        # 分块并批量处理
        for i, doc in enumerate(docs):
            chunks = self._chunk_text(doc.page_content)
            for j, chunk in enumerate(chunks):
                texts.append(chunk)
                metadatas.append({
                "source": doc.metadata.get("source", ""),
                "doc_id": i,
                "chunk_id": j,
                "faiss_id": faiss_id
            })
                faiss_id += 1
                # 每 batch_size 保存一次，防止内存溢出
                if len(texts) >= batch_size or j == len(chunks) - 1:
                    if self.db is None:
                        self.db = FAISS.from_texts(texts, embedding=self.embeddings, metadatas=metadatas)
                    else:
                        self.db.add_texts(texts, metadatas=metadatas)
                    texts.clear()
                    metadatas.clear()
            print(f"Processed {i+1}/{len(docs)} articles...")

        # 保存向量库
        if self.db:
            self.db.save_local(self.vectorstore_path)
            print(f"Vectorstore saved to {self.vectorstore_path}")
        else:
            print("No data processed.")

    def load_vectorstore(self) -> None:
        """
        加载已保存的向量存储
        """
        if not os.path.exists(self.vectorstore_path):
            raise FileNotFoundError(f"Vectorstore directory '{self.vectorstore_path}' not found.")
        self.db = FAISS.load_local(
            self.vectorstore_path,
            embeddings=self.embeddings,
            allow_dangerous_deserialization=True
        )
        print(f"Vectorstore {self.vectorstore_path} loaded.")

    def retrieve(self, query:str, top_k=4):
        """
        通过query在FAISS向量库中检索k个最相似文档，
        返回每个Document对象、其特征向量及相似度得分
        """
        # 检查向量库是否已加载
        if self.db is None:
            raise ValueError("Vectorstore尚未加载，请先调用load_vectorstore或build_vectorstore")

        query_vec = np.array(self.embeddings.embed_query(query), dtype=np.float32)
        query_vec.tolist()
        
        # 执行MMR搜索
        # if use_mmr: 
        #     docs = self.db.max_marginal_relevance_search(query, k=top_k, fetch_k=top_k * 2)
        #     doc_texts = [doc.page_content for doc in docs]
        #     doc_vecs = np.array(self.embeddings.embed_documents(doc_texts), dtype=np.float32)
        # else:

        # 执行相似度搜索
        docs_and_scores = self.db.similarity_search_with_score(query, k=top_k)
        docs, scores = zip(*docs_and_scores)
        docs = list(docs)
        scores = list(scores)
        doc_texts = [doc.page_content for doc in docs]
        doc_vecs = self.embeddings.embed_documents(doc_texts)

        # 打包结果为列表 [(doc, vec, score)]
        results = [Proof(doc, vec, score) for doc, vec, score in zip(docs, doc_vecs, scores)]

        return results, query_vec

In [31]:
class Server:
    """
    Server 类，负责：
    1) 接收客户端选择的上下文数据
    2) 验证数据完整性（通过 Proof 信息）
    3) 调用 Ollama 部署的 Qwen3:4B 模型生成答案
    """
    def __init__(self, model_name: str = "qwen3:4b"):
        self.llm = ChatOllama(model=model_name)

    def verify_documents(self):
        return

    def build_prompt(self, query: str, contexts: List[str]) -> str:
        """构造 Prompt，将 query 和上下文拼接"""
        prompt = "You are an AI assistant. Use the following contexts to answer the question:\n"
        for i, c in enumerate(contexts, 1):
            prompt += f"Context {i}: {c}\n"
        prompt += f"Question: {query}\nAnswer:"
        return prompt

    def generate_answer(self, query: str, contexts: List[str]) -> str:
        """
        验证 Proof 后调用模型生成答案
        """
        # if not self.verify_documents(contexts, proofs):
        #     raise ValueError("Proof verification failed! Data may be tampered.")
        prompt = self.build_prompt(query, contexts)
        response = self.llm.predict(prompt)
        return response

In [32]:
server = Server()

In [33]:
common_sense = Client(vectorstore_path="./common_sense_db")
common_sense.load_vectorstore()

llama_model_loader: loaded meta data with 36 key-value pairs and 310 tensors from ./models/Qwen3-Embedding/Qwen3-Embedding-0.6B-Q8_0.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = qwen3
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Qwen3 Embedding 0.6b
llama_model_loader: - kv   3:                           general.basename str              = qwen3-embedding
llama_model_loader: - kv   4:                         general.size_label str              = 0.6B
llama_model_loader: - kv   5:                            general.license str              = apache-2.0
llama_model_loader: - kv   6:                   general.base_model.count u32              = 1
llama_model_loader: - kv  

Vectorstore ./common_sense_db loaded.


In [34]:
computer_science = Client(vectorstore_path="./computer_science_coding_related_db")
computer_science.load_vectorstore()

llama_model_loader: loaded meta data with 36 key-value pairs and 310 tensors from ./models/Qwen3-Embedding/Qwen3-Embedding-0.6B-Q8_0.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = qwen3
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Qwen3 Embedding 0.6b
llama_model_loader: - kv   3:                           general.basename str              = qwen3-embedding
llama_model_loader: - kv   4:                         general.size_label str              = 0.6B
llama_model_loader: - kv   5:                            general.license str              = apache-2.0
llama_model_loader: - kv   6:                   general.base_model.count u32              = 1
llama_model_loader: - kv  

Vectorstore ./computer_science_coding_related_db loaded.


In [35]:
law = Client(vectorstore_path="./law_related_db")
law.load_vectorstore()

llama_model_loader: loaded meta data with 36 key-value pairs and 310 tensors from ./models/Qwen3-Embedding/Qwen3-Embedding-0.6B-Q8_0.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = qwen3
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Qwen3 Embedding 0.6b
llama_model_loader: - kv   3:                           general.basename str              = qwen3-embedding
llama_model_loader: - kv   4:                         general.size_label str              = 0.6B
llama_model_loader: - kv   5:                            general.license str              = apache-2.0
llama_model_loader: - kv   6:                   general.base_model.count u32              = 1
llama_model_loader: - kv  

Vectorstore ./law_related_db loaded.


In [36]:
medicine = Client(vectorstore_path="./medicine_related_db")
medicine.load_vectorstore()

llama_model_loader: loaded meta data with 36 key-value pairs and 310 tensors from ./models/Qwen3-Embedding/Qwen3-Embedding-0.6B-Q8_0.gguf (version GGUF V3 (latest))
llama_model_loader: Dumping metadata keys/values. Note: KV overrides do not apply in this output.
llama_model_loader: - kv   0:                       general.architecture str              = qwen3
llama_model_loader: - kv   1:                               general.type str              = model
llama_model_loader: - kv   2:                               general.name str              = Qwen3 Embedding 0.6b
llama_model_loader: - kv   3:                           general.basename str              = qwen3-embedding
llama_model_loader: - kv   4:                         general.size_label str              = 0.6B
llama_model_loader: - kv   5:                            general.license str              = apache-2.0
llama_model_loader: - kv   6:                   general.base_model.count u32              = 1
llama_model_loader: - kv  

Vectorstore ./medicine_related_db loaded.


In [37]:
# 模拟用户查询
query = "What is the capital of France?"

In [38]:
result1, query_vec1 = common_sense.retrieve(query)
result2, query_vec2 = computer_science.retrieve(query)
result3, query_vec3 = law.retrieve(query)
result4, query_vec4 = medicine.retrieve(query)

llama_perf_context_print:        load time =     142.02 ms
llama_perf_context_print: prompt eval time =     138.81 ms /     8 tokens (   17.35 ms per token,    57.63 tokens per second)
llama_perf_context_print:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_perf_context_print:       total time =     142.20 ms /     9 tokens
llama_perf_context_print:        load time =     142.02 ms
llama_perf_context_print: prompt eval time =      57.83 ms /     8 tokens (    7.23 ms per token,   138.35 tokens per second)
llama_perf_context_print:        eval time =       0.00 ms /     1 runs   (    0.00 ms per token,      inf tokens per second)
llama_perf_context_print:       total time =      61.67 ms /     9 tokens
llama_perf_context_print:        load time =     142.02 ms
llama_perf_context_print: prompt eval time =     827.52 ms /   147 tokens (    5.63 ms per token,   177.64 tokens per second)
llama_perf_context_print:        eval time = 

In [39]:
results = result1 + result2 + result3 + result4
results.sort(key=lambda r: r.score, reverse=True)

In [42]:
contexts = [r.document.page_content for r in results[:5]]

In [44]:
server.generate_answer(query, contexts)

  response = self.llm.predict(prompt)


"<think>\nOkay, the user is asking for the capital of France. Let me check the provided contexts to see if any of them mention France's capital.\n\nLooking through Context 1: It talks about the acre being used in some Commonwealth countries, but not France's capital. Context 2: Mentions Normandy and Paris as units of area in France, but not the capital. Context 3: Discusses land divisions in Canada and the US, not France. Context 4: Talks about macOS and AbiWord, unrelated. Context 5: Discusses the concept of the state, not geographical information.\n\nNone of the contexts provided mention the capital of France. The answer isn't in the given texts. I should inform the user that the information isn't available here and provide the correct answer directly.\n</think>\n\nThe capital of France is Paris. \n\n(Note: None of the provided contexts mention the capital of France. This answer is based on general knowledge.)"