{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RAG评估 - 使用本地文件系统版本\n",
    "\n",
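    "这个笔记本使用本地文件系统模式（而不是服务器模式）来创建和使用Chroma数据库：解析PDF和Excel制度文件、切分文本、用BGE-M3生成向量并写入本地Chroma集合，最后进行相似性检索测试。\n",
    "\n",
    "**注意：** 每次运行都会先删除并重建 `./chroma_db/zhidu_db` 目录。"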
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install the required dependencies (uncomment to run).\n",
    "# Prefer %pip over !pip: %pip installs into the environment of the running kernel.\n",
    "# %pip install langchain_huggingface\n",
    "# %pip install tabulate  # tabulate is needed to convert the Excel data to markdown"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Make sure tabulate is installed: pandas' DataFrame.to_markdown (used for the Excel data below) requires it.\n",
    "try:\n",
    "    import tabulate\n",
    "    print(\"tabulate已安装\")\n",
    "except ImportError:\n",
    "    print(\"正在安装tabulate...\")\n",
    "    # %pip (unlike !pip) installs into the environment of the running kernel.\n",
    "    %pip install tabulate\n",
    "    import importlib\n",
    "    importlib.invalidate_caches()  # let the import system notice the newly installed package\n",
    "    import tabulate  # raises ImportError here if the install did not succeed\n",
    "    print(\"tabulate安装完成\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from model import RagEmbedding, RagLLM\n",
    "from doc_parse import chunk, read_and_process_excel, logger"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "from langchain_chroma import Chroma\n",
    "from langchain_text_splitters import RecursiveCharacterTextSplitter\n",
    "import chromadb\n",
    "import os\n",
    "import shutil"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 设置参数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration: Chroma persists to the local filesystem instead of running as a server.\n",
    "COLLECTION_NAME = \"zhidu_db\"\n",
    "PERSIST_DIRECTORY = \"./chroma_db/zhidu_db\"\n",
    "\n",
    "# Rebuild from scratch on every run: remove any existing store, then recreate an empty directory.\n",
    "# WARNING: this permanently deletes whatever is currently under PERSIST_DIRECTORY.\n",
    "if os.path.exists(PERSIST_DIRECTORY):\n",
    "    shutil.rmtree(PERSIST_DIRECTORY)\n",
    "os.makedirs(PERSIST_DIRECTORY, exist_ok=True)\n",
    "\n",
    "print(f\"使用本地存储目录: {PERSIST_DIRECTORY}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Source documents to index.\n",
    "pdf_files = [\n",
    "    \"./data/zhidu_employee.pdf\",\n",
    "    \"./data/zhidu_travel.pdf\",\n",
    "]\n",
    "excel_files = [\n",
    "    \"./data/zhidu_detail.xlsx\",\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-split long plain-text chunks into pieces of roughly 128 characters with 30 characters of overlap.\n",
    "# Separators are tried in priority order: paragraph, line, then sentence/clause punctuation\n",
    "# (ASCII and full-width/CJK variants).\n",
    "r_spliter = RecursiveCharacterTextSplitter(\n",
    "    chunk_size=128,\n",
    "    chunk_overlap=30,\n",
    "    separators=[\"\\n\\n\",\n",
    "                \"\\n\",\n",
    "                \".\",\n",
    "                \"\\uff0e\",  # FULLWIDTH FULL STOP\n",
    "                \"\\u3002\",  # IDEOGRAPHIC FULL STOP\n",
    "                \",\",\n",
    "                \"\\uff0c\",  # FULLWIDTH COMMA\n",
    "                \"\\u3001\",  # IDEOGRAPHIC COMMA (a stray \"'\" used to be appended, so it rarely matched)\n",
    "                ])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 处理PDF文件"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Parse every PDF into chunks and collect their text in doc_data.\n",
    "# Chunks longer than SPLIT_THRESHOLD characters are re-split with r_spliter, except chunks\n",
    "# containing '<table>' markup, which are kept whole so a table is never cut apart.\n",
    "SPLIT_THRESHOLD = 200\n",
    "\n",
    "doc_data = []\n",
    "for pdf_file_name in pdf_files:\n",
    "    try:\n",
    "        print(f\"处理PDF文件: {pdf_file_name}\")\n",
    "        file_chunks = []  # collected per file so a failing file contributes nothing partial\n",
    "        for item in chunk(pdf_file_name, callback=logger):\n",
    "            content = item[\"content_with_weight\"]\n",
    "            if '<table>' not in content and len(content) > SPLIT_THRESHOLD:\n",
    "                file_chunks.extend(r_spliter.split_text(content))\n",
    "            else:\n",
    "                file_chunks.append(content)\n",
    "    except Exception as e:\n",
    "        # Best-effort: report the failure, skip this file and keep processing the others.\n",
    "        print(f\"处理PDF文件失败: {e}\")\n",
    "        continue\n",
    "    doc_data.extend(file_chunks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check: report how many PDF chunks were produced and preview the first three.\n",
    "print(f\"处理了 {len(doc_data)} 个PDF文档片段\")\n",
    "for i, chunk_text in enumerate(doc_data[:3]):\n",
    "    print(f\"\\n示例 {i + 1}:\")\n",
    "    print(f\"{len(chunk_text)} {'=' * 10} {chunk_text}\")\n",
    "print(\"...等多个文档片段\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 处理Excel文件"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Turn each Excel workbook into one markdown-table chunk and append it to doc_data.\n",
    "EXCEL_HEADER_ROW = 7                # index of the row used as column headers; data starts on the next row\n",
    "EXCEL_DROP_COLUMNS = range(11, 17)  # 0-based positions of the columns left out of the index\n",
    "\n",
    "for excel_file_name in excel_files:\n",
    "    try:\n",
    "        print(f\"处理Excel文件: {excel_file_name}\")\n",
    "        rows = read_and_process_excel(excel_file_name)\n",
    "        df = pd.DataFrame(rows[EXCEL_HEADER_ROW + 1:], columns=rows[EXCEL_HEADER_ROW])\n",
    "        # Drop by position rather than by label: if a header value repeats (e.g. blank cells),\n",
    "        # dropping by label would also remove same-named columns outside positions 11-16.\n",
    "        keep_positions = [pos for pos in range(df.shape[1]) if pos not in EXCEL_DROP_COLUMNS]\n",
    "        data_excel = df.iloc[:, keep_positions]\n",
    "        # Stripping spaces removes tabulate's column padding to keep the chunk compact\n",
    "        # (note: it also removes spaces inside cell values).\n",
    "        doc_data.append(data_excel.to_markdown(index=False).replace(' ', \"\"))\n",
    "    except Exception as e:\n",
    "        print(f\"处理Excel文件失败: {e}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 创建文档对象"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.documents import Document\n",
    "\n",
    "# Wrap each text chunk in a LangChain Document. Every chunk gets the same placeholder\n",
    "# metadata source \"test\" (TODO: record the originating file instead).\n",
    "documents = [\n",
    "    Document(page_content=text, metadata={\"source\": \"test\"})\n",
    "    for text in doc_data\n",
    "]\n",
    "\n",
    "print(f\"总共创建了 {len(documents)} 个文档对象\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 初始化嵌入模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embedding model: BAAI/bge-m3 on CPU, through the project's RagEmbedding wrapper.\n",
    "embedding_cls = RagEmbedding(\n",
    "    model_name=\"BAAI/bge-m3\",\n",
    "    device=\"cpu\",\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 创建Chroma数据库（使用本地文件系统模式）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed all documents and store them in a Chroma collection persisted under PERSIST_DIRECTORY.\n",
    "# Fail fast with a clear message if nothing was parsed (e.g. every input file failed above).\n",
    "if not documents:\n",
    "    raise ValueError(\"没有可写入Chroma的文档，请检查上面PDF/Excel文件的处理输出\")\n",
    "\n",
    "embedding_db = Chroma.from_documents(\n",
    "    documents=documents,\n",
    "    embedding=embedding_cls,\n",
    "    persist_directory=PERSIST_DIRECTORY,  # local on-disk storage instead of a Chroma server\n",
    "    collection_name=COLLECTION_NAME,\n",
    ")\n",
    "\n",
    "# Since Chroma 0.4 writes are persisted automatically; the old explicit persist() call is not needed.\n",
    "print(f\"成功创建Chroma数据库并保存至: {PERSIST_DIRECTORY}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试相似性搜索"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample question used to check retrieval on the freshly built database.\n",
    "query = \"迟到有什么规定？\"\n",
    "print(\"查询: '{}'\".format(query))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fetch the two chunks most similar to the query and print each with its source tag.\n",
    "related_docs = embedding_db.similarity_search(query, k=2)\n",
    "\n",
    "print(\"最相关的文档:\")\n",
    "for rank, doc in enumerate(related_docs, start=1):\n",
    "    print(f\"\\n{rank}. {doc.page_content}\")\n",
    "    source = doc.metadata['source']\n",
    "    print(f\"   来源: {source}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 重新加载数据库（演示持久化后如何重新使用）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-open the persisted collection from disk, as a later session would. The same embedding\n",
    "# function is supplied so queries are embedded the same way as the stored documents.\n",
    "loaded_db = Chroma(\n",
    "    persist_directory=PERSIST_DIRECTORY,\n",
    "    collection_name=COLLECTION_NAME,\n",
    "    embedding_function=embedding_cls,\n",
    ")\n",
    "\n",
    "# NOTE: _collection is a private attribute of langchain_chroma's Chroma (the underlying chromadb\n",
    "# collection); count() returns the number of records stored in it.\n",
    "collection_size = loaded_db._collection.count()\n",
    "print(f\"成功重新加载数据库，包含 {collection_size} 个文档\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A second query, this time against the database re-loaded from disk.\n",
    "query = \"员工福利有哪些？\"\n",
    "print(\"新查询: '{}'\".format(query))\n",
    "\n",
    "related_docs = loaded_db.similarity_search(query, k=2)\n",
    "\n",
    "print(\"最相关的文档:\")\n",
    "for rank, doc in enumerate(related_docs, start=1):\n",
    "    print(f\"\\n{rank}. {doc.page_content}\")\n",
    "    source = doc.metadata['source']\n",
    "    print(f\"   来源: {source}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}