In [None]:
# ===== 代码块 1: 导入和初始化 =====
from __future__ import annotations
from pathlib import Path
from typing import List, Dict, Any, Optional
import os
from langchain_openai import OpenAIEmbeddings
from langchain_text_splitters import MarkdownHeaderTextSplitter
from langchain_core.documents import Document
from langchain_community.vectorstores import FAISS
from dotenv import load_dotenv, find_dotenv

# Load environment variables from a .env file located by find_dotenv();
# override=True lets values from that file replace variables already set
# in the process environment.
load_dotenv(find_dotenv(), override=True)

# Root directory under which the per-file working directories are created.
DATA_ROOT = Path("data_debug")

print("✅ 模块导入和初始化完成")

In [None]:
# Cell 3: 目录管理函数
def workdir(file_id: str) -> Path:
    """Return the working directory for ``file_id``, creating it when absent.

    The directory is located directly under ``DATA_ROOT``. Missing parent
    directories are created and an existing directory is left untouched.
    """
    target = DATA_ROOT.joinpath(file_id)
    target.mkdir(parents=True, exist_ok=True)
    return target

def markdown_path(file_id: str) -> Path:
    """Return the path of ``output.md`` inside the working directory of ``file_id``.

    The working directory is created as a side effect of calling ``workdir``;
    the Markdown file itself is not created here.
    """
    base = workdir(file_id)
    return base.joinpath("output.md")

def index_dir(file_id: str) -> Path:
    """Return the ``index_faiss`` directory for ``file_id``, creating it when absent.

    The directory sits inside the working directory returned by ``workdir``.
    """
    target = workdir(file_id).joinpath("index_faiss")
    target.mkdir(parents=True, exist_ok=True)
    return target

# Emit a confirmation so the cell produces visible output once the helpers exist.
print("Path helpers defined.")

In [None]:
# Cell 4: generate a test Markdown file
FILE_ID = "test_doc_001"

# Sample document with two level-1 sections and three level-2 subsections,
# used as input for the splitting and indexing cells.
dummy_md_content = """
# 项目介绍
这是一个用于测试 RAG 流程的文档。这里是项目介绍的正文部分。

## 技术栈
我们使用了 Python, LangChain 和 FAISS。
- Python: 编程语言
- FAISS: 向量库

# 部署指南
## 环境准备
请确保安装了 Python 3.10 以上版本。

## 启动服务
运行 `python main.py` 启动服务。
"""

# Write the sample content to this file id's output.md (UTF-8 encoded).
md_file = markdown_path(FILE_ID)
md_file.write_text(dummy_md_content, encoding="utf-8")
print(f"Created dummy markdown at: {md_file}")

In [None]:
# Cell 5: 测试 Markdown 切分
from langchain_text_splitters import MarkdownHeaderTextSplitter


def split_markdown(md_text: str, max_chars: Optional[int] = 8000) -> List[Document]:
    """Split Markdown text into header-scoped chunks and clean them up.

    The text is split on level-1 (``#``) and level-2 (``##``) headers with
    ``MarkdownHeaderTextSplitter``. Each resulting chunk is stripped of
    surrounding whitespace, dropped if it ends up empty, and truncated to at
    most ``max_chars`` characters.

    Args:
        md_text: Raw Markdown source.
        max_chars: Maximum number of characters kept per chunk. Anything past
            this limit is discarded (not moved to another chunk). Pass
            ``None`` to disable truncation. Defaults to 8000.

    Returns:
        The cleaned chunks as ``Document`` objects; each carries the metadata
        produced by the splitter for its section.

    Raises:
        ValueError: If ``max_chars`` is given but not positive.
    """
    if max_chars is not None and max_chars <= 0:
        raise ValueError("max_chars must be a positive integer or None")

    headers_to_split_on = [
        ("#", "Header 1"),
        ("##", "Header 2"),
        # ("###", "Header 3")  # add this entry for finer-grained splitting
    ]
    splitter = MarkdownHeaderTextSplitter(headers_to_split_on=headers_to_split_on)

    cleaned: List[Document] = []
    for section in splitter.split_text(md_text):
        text = (section.page_content or "").strip()
        if not text:
            # Sections that contain only whitespace carry no information to index.
            continue
        if max_chars is not None and len(text) > max_chars:
            # Hard cut: the tail of an over-long section is dropped.
            text = text[:max_chars]
        cleaned.append(Document(page_content=text, metadata=section.metadata))
    return cleaned

# Read back the Markdown file generated earlier and run the splitter on it.
md_text = markdown_path(FILE_ID).read_text(encoding="utf-8")
docs = split_markdown(md_text)

# Print every chunk's metadata and content for manual inspection.
print(f"Split into {len(docs)} chunks.\n")
for i, doc in enumerate(docs):
    print(f"--- Chunk {i+1} ---")
    print(f"Metadata: {doc.metadata}")
    print(f"Content: {doc.page_content}\n")

In [None]:
# Cell 6: 加载 Embedding 模型 (根据rag_agent.py修改)
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
# from langchain_openai import OpenAIEmbeddings # 原OpenAI实现
# from langchain_community.embeddings import FakeEmbeddings # 仅用于无 Key 调试

def load_embeddings():
    """Load an embedding model, preferring a local BGE model.

    Resolution order:

    1. ``HuggingFaceBgeEmbeddings`` using the model named by the
       ``EMBEDDING_MODEL_NAME`` environment variable (default
       ``"BAAI/bge-small-zh-v1.5"``), on CPU, with normalized embeddings.
    2. If that raises, ``OpenAIEmbeddings`` (``text-embedding-3-large``) when
       ``OPENAI_API_KEY`` or ``OPENAI_EMBEDDING_API_KEY`` is set; an optional
       base URL is taken from ``OPENAI_BASE_URL`` or
       ``OPENAI_EMBEDDING_BASE_URL``.
    3. Otherwise ``FakeEmbeddings(size=512)``, intended for testing only.

    Returns:
        The embeddings object selected by the order above.

    NOTE(review): the fallbacks may not produce vectors compatible with an
    index built by the local model -- confirm before mixing them.
    """
    # --- Debug option: use a local model instead of OpenAI ---
    # Resolve the model path/name once so the log line and the constructor agree.
    model_name = os.getenv("EMBEDDING_MODEL_NAME", "BAAI/bge-small-zh-v1.5")
    print(f"正在加载 Embedding 模型: {model_name} ...")
    try:
        # Same local model configuration as rag_agent.py (per the original note).
        return HuggingFaceBgeEmbeddings(
            model_name=model_name,
            model_kwargs={'device': 'cpu'},  # switch to 'cuda' when a GPU is available
            encode_kwargs={'normalize_embeddings': True},
        )
    except Exception as e:
        # Deliberate best-effort: any failure of the local model triggers the fallbacks.
        print(f"❌ 本地模型加载失败: {e}")

    # Fall back to OpenAI when an API key is configured.
    api_key = os.getenv("OPENAI_API_KEY") or os.getenv("OPENAI_EMBEDDING_API_KEY")
    base_url = os.getenv("OPENAI_BASE_URL") or os.getenv("OPENAI_EMBEDDING_BASE_URL")

    if api_key:
        print("回退到OpenAI Embeddings...")
        kwargs = {"api_key": api_key}
        if base_url:
            kwargs["base_url"] = base_url
        return OpenAIEmbeddings(model="text-embedding-3-large", **kwargs)

    # Last resort: fake embeddings (testing only).
    print("⚠️ Warning: No API Key found. Using FakeEmbeddings (dimension=512).")
    from langchain_community.embeddings import FakeEmbeddings
    return FakeEmbeddings(size=512)  # note: bge-small-zh-v1.5 has dimension 512

# Smoke-test the loader.
emb_model = load_embeddings()
# Embed a single word to verify the model/service is reachable and report the vector size.
try:
    vec = emb_model.embed_query("test")
    print(f"Embedding check passed. Vector dimension: {len(vec)}")
except Exception as e:
    print(f"❌ Embedding connection failed: {e}")

In [None]:
# Cell 7: 构建并保存索引
from langchain_community.vectorstores import FAISS

def build_faiss_index(file_id: str) -> Dict[str, Any]:
    """Build a FAISS index from the Markdown of ``file_id`` and save it to disk.

    The Markdown at ``markdown_path(file_id)`` is split with
    ``split_markdown``, embedded with the model from ``load_embeddings`` and
    written to ``index_dir(file_id)``.

    Returns:
        ``{"ok": True, "chunks": <count>}`` on success. On failure
        ``{"ok": False, "error": <reason>}`` where the reason is
        ``"MARKDOWN_NOT_FOUND"``, ``"EMPTY_MD"`` or the message of the
        exception raised while building or saving the index.
    """
    source = markdown_path(file_id)
    if not source.exists():
        return {"ok": False, "error": "MARKDOWN_NOT_FOUND"}

    chunks = split_markdown(source.read_text(encoding="utf-8"))
    if not chunks:
        return {"ok": False, "error": "EMPTY_MD"}

    chunk_count = len(chunks)
    print(f"Vectorizing {chunk_count} documents...")
    embedder = load_embeddings()
    try:
        store = FAISS.from_documents(chunks, embedding=embedder)
        target = index_dir(file_id)
        store.save_local(str(target))
        print(f"Index saved to {target}")
    except Exception as exc:
        # Report the failure through the result dict instead of raising.
        print(f"Error building index: {exc}")
        return {"ok": False, "error": str(exc)}
    return {"ok": True, "chunks": chunk_count}

# Run the index build for the sample file and show the result dict.
build_result = build_faiss_index(FILE_ID)
print("Build Result:", build_result)

In [None]:
# Cell 8: 搜索测试
def search_faiss(file_id: str, query: str, k: int = 5) -> Dict[str, Any]:
    """Run a similarity search against the saved FAISS index of ``file_id``.

    Args:
        file_id: Identifier whose index directory is searched.
        query: Query text to embed and search for.
        k: Maximum number of hits to return.

    Returns:
        ``{"ok": True, "results": [...]}`` on success, where each result has
        ``"text"``, ``"score"`` and ``"metadata"`` keys. On failure
        ``{"ok": False, "error": <reason>}`` where the reason is
        ``"INDEX_NOT_FOUND"`` or the message of the exception raised while
        loading or querying the index (same convention as
        ``build_faiss_index``).
    """
    idx = index_dir(file_id)
    # The saved index consists of index.faiss plus index.pkl (docstore);
    # loading needs both, so a partial index is reported as missing.
    if not ((idx / "index.faiss").exists() and (idx / "index.pkl").exists()):
        return {"ok": False, "error": "INDEX_NOT_FOUND"}

    print(f"Loading index from {idx} ...")
    embeddings = load_embeddings()

    try:
        # SECURITY: allow_dangerous_deserialization=True is required because the
        # docstore is stored with pickle. Unpickling can execute arbitrary code,
        # so only load index directories produced by this pipeline.
        vs = FAISS.load_local(str(idx), embeddings, allow_dangerous_deserialization=True)

        print(f"Searching for: '{query}'")
        hits = vs.similarity_search_with_score(query, k=k)
    except Exception as e:
        # Report load/search failures through the result dict, like build_faiss_index.
        print(f"Error searching index: {e}")
        return {"ok": False, "error": str(e)}

    results = [
        {
            "text": doc.page_content,
            "score": float(score),  # FAISS L2 distance: smaller means more similar
            "metadata": doc.metadata,
        }
        for doc, score in hits
    ]
    return {"ok": True, "results": results}

# Run a search.
# Try a query related to "环境" (environment).
search_result = search_faiss(FILE_ID, "如何准备环境？", k=2)

# Print each hit on success; otherwise dump the error payload.
if search_result["ok"]:
    for i, item in enumerate(search_result["results"]):
        print(f"\nResult {i+1} (Score: {item['score']:.4f}):")
        print(f"Text: {item['text']}")
        print(f"Meta: {item['metadata']}")
else:
    print("Search failed:", search_result)