In [None]:
# notebooks/exploration.ipynb
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# NL2SQL 数据探索笔记本\n",
    "\n",
    "这个笔记本用于探索数据库结构和测试Vanna的NL2SQL功能。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# 导入必要的模块\n",
    "import os\n",
    "import sys\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "# Put the parent directory (resolved against the current working directory)\n",
    "# on sys.path so that the `app` package imported below can be found.\n",
    "module_path = os.path.abspath(os.pardir)\n",
    "if module_path not in sys.path:\n",
    "    sys.path.append(module_path)\n",
    "\n",
    "# 导入项目模块\n",
    "from app.config import DATABASE_CONFIG\n",
    "from app.db.connection import DatabaseConnection\n",
    "from app.langchain.llm_config import LLMFactory, EmbeddingFactory\n",
    "from app.vanna.setup import VannaSetup\n",
    "from app.vanna.query_processor import QueryProcessor"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 连接数据库"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Create the project's database connection wrapper. The constructor takes no\n",
    "# arguments here; presumably it reads DATABASE_CONFIG itself -- TODO confirm.\n",
    "db_connection = DatabaseConnection()\n",
    "# Report the database name and host taken from the configuration.\n",
    "# NOTE(review): this message is printed as soon as the constructor returns;\n",
    "# whether a live connection exists at that point depends on DatabaseConnection.\n",
    "print(f\"已连接到数据库：{DATABASE_CONFIG['dbname']} at {DATABASE_CONFIG['host']}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 探索数据库结构"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Fetch the schema description from the connection helper and print it.\n",
    "# The variable name suggests the value is DDL text -- TODO confirm.\n",
    "ddl = db_connection.get_database_schema()\n",
    "print(ddl)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 初始化Vanna"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Build the LLM and embedding models through the project factories\n",
    "llm_model = LLMFactory.create_llm()\n",
    "embedding_model = EmbeddingFactory.create_embedding()\n",
    "\n",
    "# Initialise Vanna with both models; initialize_vanna() returns the instance\n",
    "# that the rest of the notebook uses for training and query processing\n",
    "vanna_setup = VannaSetup(llm_model, embedding_model)\n",
    "vanna_instance = vanna_setup.initialize_vanna()\n",
    "\n",
    "# Hand the existing database connection to the Vanna setup\n",
    "vanna_setup.connect_to_database(db_connection)\n",
    "\n",
    "# Create the query processor used by the query cells below\n",
    "query_processor = QueryProcessor(vanna_instance, db_connection)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试简单查询"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Simple query: ask for every customer\n",
    "question = \"列出所有客户\"\n",
    "result = query_processor.process_query(question)\n",
    "\n",
    "if not result['success']:\n",
    "    # Failure path: report the error returned by the processor\n",
    "    print(f\"查询失败: {result['error']}\")\n",
    "else:\n",
    "    # Success path: show the generated SQL, then the rows as a DataFrame\n",
    "    print(f\"生成的SQL: {result['sql']}\\n\")\n",
    "    df = pd.DataFrame(result['results'])\n",
    "    display(df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试条件过滤查询"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Filtered query: customers located in New York\n",
    "question = \"查找纽约的客户\"\n",
    "result = query_processor.process_query(question)\n",
    "\n",
    "if not result['success']:\n",
    "    # Failure path: report the error returned by the processor\n",
    "    print(f\"查询失败: {result['error']}\")\n",
    "else:\n",
    "    # Success path: show the generated SQL, then the rows as a DataFrame\n",
    "    print(f\"生成的SQL: {result['sql']}\\n\")\n",
    "    df = pd.DataFrame(result['results'])\n",
    "    display(df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试多表连接查询"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Multi-table join query: customer name for every order\n",
    "question = \"显示每个订单的客户名称\"\n",
    "result = query_processor.process_query(question)\n",
    "\n",
    "if not result['success']:\n",
    "    # Failure path: report the error returned by the processor\n",
    "    print(f\"查询失败: {result['error']}\")\n",
    "else:\n",
    "    # Success path: show the generated SQL, then the rows as a DataFrame\n",
    "    print(f\"生成的SQL: {result['sql']}\\n\")\n",
    "    df = pd.DataFrame(result['results'])\n",
    "    display(df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试复杂查询"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Complex query: top 5 customers by purchase amount\n",
    "question = \"查找购买金额最高的前5名客户\"\n",
    "result = query_processor.process_query(question)\n",
    "\n",
    "if result['success']:\n",
    "    print(f\"生成的SQL: {result['sql']}\\n\")\n",
    "    df = pd.DataFrame(result['results'])\n",
    "    display(df)\n",
    "\n",
    "    # The SQL is model-generated, so the result column names are not guaranteed.\n",
    "    # Resolve them case-insensitively (some databases, e.g. PostgreSQL, fold\n",
    "    # unquoted aliases to lower case) and skip the chart when they are absent,\n",
    "    # instead of letting seaborn raise on a missing column.\n",
    "    columns_by_lower = {str(col).lower(): col for col in df.columns}\n",
    "    name_col = columns_by_lower.get('lastname')\n",
    "    amount_col = columns_by_lower.get('totalamount')\n",
    "\n",
    "    if df.empty or name_col is None or amount_col is None:\n",
    "        print(f\"跳过可视化: 结果为空或缺少 LastName/TotalAmount 列 (实际列: {list(df.columns)})\")\n",
    "    else:\n",
    "        # Visualise the result as a bar chart\n",
    "        plt.figure(figsize=(10, 6))\n",
    "        sns.barplot(x=name_col, y=amount_col, data=df)\n",
    "        plt.title('购买金额最高的前5名客户')\n",
    "        plt.xlabel('客户')\n",
    "        plt.ylabel('总金额')\n",
    "        plt.xticks(rotation=45)\n",
    "        plt.tight_layout()\n",
    "        plt.show()\n",
    "else:\n",
    "    print(f\"查询失败: {result['error']}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 添加训练数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Register one question/SQL pair with Vanna as an additional training example.\n",
    "# In the SQL below, CurrentYearSales sums orders dated in the previous calendar\n",
    "# year and PreviousYearSales sums orders dated the year before that; GrowthRate\n",
    "# is their percentage change, and the HAVING clause drops products with no\n",
    "# sales in the earlier year.\n",
    "# NOTE(review): whether re-running this cell stores a duplicate example depends\n",
    "# on the Vanna backend -- confirm before running it repeatedly.\n",
    "vanna_instance.train(\n",
    "    question=\"查找去年销售额增长最快的产品\",\n",
    "    sql=\"\"\"\n",
    "    SELECT \n",
    "        p.ProductName,\n",
    "        SUM(CASE WHEN EXTRACT(YEAR FROM o.OrderDate) = EXTRACT(YEAR FROM CURRENT_DATE) - 1 THEN oi.Quantity * p.Price ELSE 0 END) as CurrentYearSales,\n",
    "        SUM(CASE WHEN EXTRACT(YEAR FROM o.OrderDate) = EXTRACT(YEAR FROM CURRENT_DATE) - 2 THEN oi.Quantity * p.Price ELSE 0 END) as PreviousYearSales,\n",
    "        (SUM(CASE WHEN EXTRACT(YEAR FROM o.OrderDate) = EXTRACT(YEAR FROM CURRENT_DATE) - 1 THEN oi.Quantity * p.Price ELSE 0 END) - \n",
    "         SUM(CASE WHEN EXTRACT(YEAR FROM o.OrderDate) = EXTRACT(YEAR FROM CURRENT_DATE) - 2 THEN oi.Quantity * p.Price ELSE 0 END)) /\n",
    "         NULLIF(SUM(CASE WHEN EXTRACT(YEAR FROM o.OrderDate) = EXTRACT(YEAR FROM CURRENT_DATE) - 2 THEN oi.Quantity * p.Price ELSE 0 END), 0) * 100 as GrowthRate\n",
    "    FROM \n",
    "        Products p\n",
    "        JOIN OrderItems oi ON p.ProductID = oi.ProductID\n",
    "        JOIN Orders o ON oi.OrderID = o.OrderID\n",
    "    WHERE \n",
    "        EXTRACT(YEAR FROM o.OrderDate) >= EXTRACT(YEAR FROM CURRENT_DATE) - 2\n",
    "    GROUP BY \n",
    "        p.ProductID, p.ProductName\n",
    "    HAVING \n",
    "        SUM(CASE WHEN EXTRACT(YEAR FROM o.OrderDate) = EXTRACT(YEAR FROM CURRENT_DATE) - 2 THEN oi.Quantity * p.Price ELSE 0 END) > 0\n",
    "    ORDER BY \n",
    "        GrowthRate DESC\n",
    "    LIMIT 10\n",
    "    \"\"\"\n",
    ")\n",
    "\n",
    "print(\"自定义训练数据已添加\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试自定义查询"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Re-ask the question that was just added as a training example\n",
    "question = \"查找去年销售额增长最快的产品\"\n",
    "result = query_processor.process_query(question)\n",
    "\n",
    "if result['success']:\n",
    "    print(f\"生成的SQL: {result['sql']}\\n\")\n",
    "    df = pd.DataFrame(result['results'])\n",
    "    display(df)\n",
    "\n",
    "    # Both plotted columns must exist. Resolve them case-insensitively, because\n",
    "    # some databases (e.g. PostgreSQL) fold unquoted aliases such as GrowthRate\n",
    "    # to lower case, which would otherwise make this chart silently disappear.\n",
    "    columns_by_lower = {str(col).lower(): col for col in df.columns}\n",
    "    product_col = columns_by_lower.get('productname')\n",
    "    growth_col = columns_by_lower.get('growthrate')\n",
    "\n",
    "    if df.empty or product_col is None or growth_col is None:\n",
    "        print(f\"跳过可视化: 结果为空或缺少 ProductName/GrowthRate 列 (实际列: {list(df.columns)})\")\n",
    "    else:\n",
    "        # Visualise the result as a bar chart\n",
    "        plt.figure(figsize=(12, 8))\n",
    "        sns.barplot(x=product_col, y=growth_col, data=df)\n",
    "        plt.title('销售额增长最快的产品')\n",
    "        plt.xlabel('产品')\n",
    "        plt.ylabel('增长率(%)')\n",
    "        plt.xticks(rotation=45)\n",
    "        plt.tight_layout()\n",
    "        plt.show()\n",
    "else:\n",
    "    print(f\"查询失败: {result['error']}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 探索数据分布"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "source": [
    "# Load the Orders table\n",
    "result = db_connection.execute_query(\"SELECT * FROM Orders\")\n",
    "# NOTE(review): the first element of the return value is used as the row data;\n",
    "# confirm the exact return shape against DatabaseConnection.execute_query.\n",
    "orders_df = pd.DataFrame(result[0])\n",
    "\n",
    "# Resolve the order-date column case-insensitively: some databases (e.g.\n",
    "# PostgreSQL) report unquoted column names in lower case, which would make a\n",
    "# case-sensitive 'OrderDate' check fail and this cell silently do nothing.\n",
    "date_col = next((col for col in orders_df.columns if str(col).lower() == 'orderdate'), None)\n",
    "\n",
    "# Count orders per calendar month and plot the trend\n",
    "if not orders_df.empty and date_col is not None:\n",
    "    orders_df[date_col] = pd.to_datetime(orders_df[date_col])\n",
    "    orders_df['YearMonth'] = orders_df[date_col].dt.to_period('M')\n",
    "\n",
    "    monthly_orders = orders_df.groupby('YearMonth').size().reset_index(name='OrderCount')\n",
    "    # Periods are converted to strings so they can be used as axis labels\n",
    "    monthly_orders['YearMonth'] = monthly_orders['YearMonth'].astype(str)\n",
    "\n",
    "    plt.figure(figsize=(12, 6))\n",
    "    sns.lineplot(x='YearMonth', y='OrderCount', data=monthly_orders, marker='o')\n",
    "    plt.title('月度订单数量趋势')\n",
    "    plt.xlabel('年月')\n",
    "    plt.ylabel('订单数量')\n",
    "    plt.xticks(rotation=45)\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "else:\n",
    "    print(f\"跳过月度趋势图: 订单数据为空或缺少 OrderDate 列 (实际列: {list(orders_df.columns)})\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 总结\n",
    "\n",
    "在这个笔记本中，我们探索了数据库结构，测试了Vanna的NL2SQL功能，并分析了查询结果。我们看到Vanna能够处理从简单到复杂的各种查询，并且可以通过训练不断改进。"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}