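# TextSQL RAG Pipeline Study Notes

This notebook walks through building a complete TextSQL (text-to-SQL) RAG (Retrieval-Augmented Generation) pipeline, from schema modeling and retrieval through SQL generation and evaluation.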

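## Table of Contents
1. [RAG Concepts](#1-rag-concepts)
2. [TextSQL Basics](#2-textsql-basics)
3. [Environment Setup](#3-environment-setup)
4. [Data Preprocessing](#4-data-preprocessing)
5. [Vectorization and Indexing](#5-vectorization-and-indexing)
6. [Retrieval System](#6-retrieval-system)
7. [SQL Generation](#7-sql-generation)
8. [Complete Pipeline](#8-complete-pipeline)
9. [Evaluation and Optimization](#9-evaluation-and-optimization)
10. [Practical Application Examples](#10-practical-application-examples)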

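## 1. RAG Concepts

### What Is RAG?
RAG (Retrieval-Augmented Generation) is an architecture that combines retrieval with generation:
1. **Retrieval stage**: find relevant information in a knowledge base
2. **Generation stage**: generate an answer based on the retrieved information

### Characteristics of TextSQL RAG
- Focuses on converting natural language into SQL queries
- Incorporates database schema information
- Supports generating complex database queries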
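
To make "incorporates database schema information" concrete, here is a minimal sketch of how a question, a schema description, and a few retrieved examples could be assembled into a single few-shot prompt for a language model. The function name `build_prompt` and the prompt layout are assumptions made for illustration; the notebook itself uses templates rather than an LLM call.

```python
from typing import Dict, List


def build_prompt(question: str, schema_info: str, examples: List[Dict[str, str]]) -> str:
    """Assemble a few-shot text-to-SQL prompt (hypothetical layout)."""
    lines = [
        "Translate the question into one SQLite SELECT statement.",
        "",
        "Schema:",
        schema_info.strip(),
        "",
    ]
    # Each retrieved example becomes one question/SQL demonstration pair.
    for ex in examples:
        lines.append(f"Question: {ex['natural_language']}")
        lines.append(f"SQL: {ex['sql']}")
        lines.append("")
    # The new question goes last, leaving the SQL for the model to complete.
    lines.append(f"Question: {question}")
    lines.append("SQL:")
    return "\n".join(lines)


prompt = build_prompt(
    "List all products",
    "products(product_id, name, price, category)",
    [{"natural_language": "Show all users", "sql": "SELECT * FROM users;"}],
)
print(prompt)
```

Keeping the schema ahead of the examples means the same prefix can be reused across questions; only the retrieved pairs and the final question change.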

In [None]:
# Basic concept example
import pandas as pd
import numpy as np
from typing import List, Dict, Any

# Sketch of the RAG flow
class SimpleRAGConcept:
    def __init__(self):
        self.knowledge_base = []
        self.query_history = []

    def retrieve(self, query: str) -> List[str]:
        """Retrieve relevant documents."""
        # Simplified retrieval: rank documents by shared character bigrams.
        # Splitting on whitespace would not work here, because the Chinese
        # sample text below contains no spaces.
        q = query.lower()
        bigrams = {q[i:i + 2] for i in range(len(q) - 1)}
        scored = [(sum(bg in doc.lower() for bg in bigrams), doc)
                  for doc in self.knowledge_base]
        scored.sort(key=lambda pair: pair[0], reverse=True)
        return [doc for score, doc in scored if score > 0][:3]  # top 3 matches

    def generate(self, query: str, context: List[str]) -> str:
        """Generate an answer from the context."""
        # Simplified generation logic
        return f"Answer to query '{query}' based on context {context}"

    def rag_pipeline(self, query: str) -> str:
        """Run the complete RAG pipeline."""
        # 1. Retrieve
        relevant_context = self.retrieve(query)
        # 2. Generate
        response = self.generate(query, relevant_context)
        # 3. Record the query history
        self.query_history.append({'query': query, 'response': response})
        return response

# Example usage
rag_demo = SimpleRAGConcept()
rag_demo.knowledge_base = [
    "用戶表包含用戶ID、姓名、郵箱等字段",
    "訂單表記錄了所有的購買信息",
    "產品表存儲產品的詳細信息"
]

result = rag_demo.rag_pipeline("如何查詢用戶的訂單信息？")
print("RAG example result:")
print(result)

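## 2. TextSQL Basics

### Core Concepts
TextSQL is the process of converting a natural-language query into a SQL statement. Its key components are:
- **Schema Understanding**: understanding the database structure
- **Intent Recognition**: identifying what the user wants
- **SQL Generation**: producing the corresponding SQL statement

In [None]:
# Basic TextSQL components
class DatabaseSchema:
    """Database schema description."""
    def __init__(self):
        self.tables = {}
        self.relationships = []

    def add_table(self, table_name: str, columns: Dict[str, str]):
        """Add a table definition."""
        self.tables[table_name] = columns

    def add_relationship(self, table1: str, column1: str, table2: str, column2: str):
        """Add a relationship between two tables."""
        self.relationships.append({
            'from_table': table1,
            'from_column': column1,
            'to_table': table2,
            'to_column': column2
        })

    def get_schema_info(self) -> str:
        """Return the schema as text, including table relationships."""
        schema_info = "Database schema:\n"
        for table, columns in self.tables.items():
            schema_info += f"Table {table}: {columns}\n"
        for rel in self.relationships:
            schema_info += (f"Relationship: {rel['from_table']}.{rel['from_column']} -> "
                            f"{rel['to_table']}.{rel['to_column']}\n")
        return schema_info

# Create the sample database schema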
schema = DatabaseSchema()
schema.add_table('users', {
    'user_id': 'INT PRIMARY KEY',
    'name': 'VARCHAR(100)',
    'email': 'VARCHAR(100)',
    'created_at': 'DATETIME'
})

schema.add_table('orders', {
    'order_id': 'INT PRIMARY KEY',
    'user_id': 'INT',
    'product_id': 'INT',
    'quantity': 'INT',
    'order_date': 'DATETIME'
})

schema.add_table('products', {
    'product_id': 'INT PRIMARY KEY',
    'name': 'VARCHAR(100)',
    'price': 'DECIMAL(10,2)',
    'category': 'VARCHAR(50)'
})

schema.add_relationship('orders', 'user_id', 'users', 'user_id')
schema.add_relationship('orders', 'product_id', 'products', 'product_id')

print(schema.get_schema_info())
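
A schema description is often easier for a model (or a reader) to use when it is rendered as `CREATE TABLE` statements rather than as a Python dict. The sketch below shows one way to do that from the same `tables` and `relationships` structures used above. The helper name `schema_to_ddl` is hypothetical and is not part of the notebook's classes.

```python
from typing import Dict, List


def schema_to_ddl(tables: Dict[str, Dict[str, str]],
                  relationships: List[Dict[str, str]]) -> str:
    """Render table and relationship dicts as CREATE TABLE statements."""
    statements = []
    for table, columns in tables.items():
        parts = [f"  {name} {col_type}" for name, col_type in columns.items()]
        # Attach each relationship to the table that holds the referencing column.
        for rel in relationships:
            if rel["from_table"] == table:
                parts.append(
                    f"  FOREIGN KEY ({rel['from_column']}) "
                    f"REFERENCES {rel['to_table']} ({rel['to_column']})"
                )
        statements.append(f"CREATE TABLE {table} (\n" + ",\n".join(parts) + "\n);")
    return "\n\n".join(statements)


tables = {
    "users": {"user_id": "INT PRIMARY KEY", "name": "VARCHAR(100)"},
    "orders": {"order_id": "INT PRIMARY KEY", "user_id": "INT"},
}
relationships = [
    {"from_table": "orders", "from_column": "user_id",
     "to_table": "users", "to_column": "user_id"},
]
ddl = schema_to_ddl(tables, relationships)
print(ddl)
```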

## 3. Environment Setup

### Installing the Required Libraries

In [None]:
# Install the required packages (for example, in a Kaggle environment)
import sys
import subprocess

def install_package(package):
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])

# Core package list
# sqlite3 ships with the Python standard library, so it is not installed via pip
packages = [
    'transformers',
    'torch',
    'sentence-transformers',
    'faiss-cpu',
    'chromadb',
    'langchain',
    'openai',
    'sqlparse'
]

print("Installing the required packages...")
for package in packages:
    try:
        install_package(package)
        print(f"✓ {package} installed")
    except Exception as e:
        print(f"✗ {package} failed to install: {e}")

print("\nEnvironment setup complete!")

In [None]:
# Import the required libraries
import os
import json
import sqlite3
import pandas as pd
import numpy as np
from typing import List, Dict, Any, Tuple
import warnings
warnings.filterwarnings('ignore')

# Vectorization and retrieval
from sentence_transformers import SentenceTransformer
import faiss

# Natural language processing
from transformers import AutoTokenizer, AutoModel
import torch

# SQL parsing
import sqlparse

print("All required libraries imported successfully!")

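## 4. Data Preprocessing

### Creating the Sample Database and Data

In [None]:
# Create a sample SQLite database
def create_sample_database():
    """Create the sample e-commerce database."""
    conn = sqlite3.connect('sample_ecommerce.db')
    cursor = conn.cursor()

    # Create the users table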
    cursor.execute('''
    CREATE TABLE IF NOT EXISTS users (
        user_id INTEGER PRIMARY KEY,
        name TEXT NOT NULL,
        email TEXT UNIQUE NOT NULL,
        created_at DATETIME DEFAULT CURRENT_TIMESTAMP
    )
    ''')
    
    # Create the products table
    cursor.execute('''
    CREATE TABLE IF NOT EXISTS products (
        product_id INTEGER PRIMARY KEY,
        name TEXT NOT NULL,
        price DECIMAL(10,2) NOT NULL,
        category TEXT NOT NULL,
        description TEXT
    )
    ''')
    
    # Create the orders table
    cursor.execute('''
    CREATE TABLE IF NOT EXISTS orders (
        order_id INTEGER PRIMARY KEY,
        user_id INTEGER,
        product_id INTEGER,
        quantity INTEGER NOT NULL,
        order_date DATETIME DEFAULT CURRENT_TIMESTAMP,
        FOREIGN KEY (user_id) REFERENCES users (user_id),
        FOREIGN KEY (product_id) REFERENCES products (product_id)
    )
    ''')
    
    # Insert sample data
    users_data = [
        (1, '張三', 'zhang@example.com'),
        (2, '李四', 'li@example.com'),
        (3, '王五', 'wang@example.com')
    ]
    
    products_data = [
        (1, 'iPhone 15', 999.99, '電子產品', '最新款智能手機'),
        (2, 'MacBook Pro', 1299.99, '電子產品', '專業筆記本電腦'),
        (3, '咖啡機', 199.99, '家電', '全自動咖啡機'),
        (4, '書籍：Python編程', 29.99, '圖書', 'Python編程入門教程')
    ]
    
    orders_data = [
        (1, 1, 1, 1),
        (2, 1, 3, 1),
        (3, 2, 2, 1),
        (4, 3, 4, 2)
    ]
    
    cursor.executemany('INSERT OR REPLACE INTO users (user_id, name, email) VALUES (?, ?, ?)', users_data)
    cursor.executemany('INSERT OR REPLACE INTO products (product_id, name, price, category, description) VALUES (?, ?, ?, ?, ?)', products_data)
    cursor.executemany('INSERT OR REPLACE INTO orders (order_id, user_id, product_id, quantity) VALUES (?, ?, ?, ?)', orders_data)
    
    conn.commit()
    conn.close()
    print("Sample database created!")

# Create the database
create_sample_database()

# Verify the data
def verify_database():
    conn = sqlite3.connect('sample_ecommerce.db')

    print("Users table:")
    users_df = pd.read_sql_query('SELECT * FROM users', conn)
    print(users_df)

    print("\nProducts table:")
    products_df = pd.read_sql_query('SELECT * FROM products', conn)
    print(products_df)

    print("\nOrders table:")
    orders_df = pd.read_sql_query('SELECT * FROM orders', conn)
    print(orders_df)

    conn.close()

verify_database()
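
The hand-written `DatabaseSchema` above can drift from the real database (for example, the SQLite `products` table has a `description` column that the schema object does not list). One way to avoid that is to read the schema from the database itself. The sketch below does this with SQLite's `sqlite_master` catalog and `PRAGMA table_info`; the function name `introspect_schema` is an assumption, and the demo uses an in-memory database so that it runs on its own.

```python
import sqlite3
from typing import Dict, List


def introspect_schema(conn: sqlite3.Connection) -> Dict[str, List[str]]:
    """Return {table: ["column TYPE", ...]} read from SQLite's catalog."""
    schema: Dict[str, List[str]] = {}
    tables = conn.execute(
        "SELECT name FROM sqlite_master "
        "WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY name"
    ).fetchall()
    for (table,) in tables:
        # PRAGMA table_info rows: (cid, name, type, notnull, dflt_value, pk)
        columns = conn.execute(f'PRAGMA table_info("{table}")').fetchall()
        schema[table] = [f"{col[1]} {col[2]}" for col in columns]
    return schema


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE users (user_id INTEGER PRIMARY KEY, name TEXT NOT NULL)")
conn.execute(
    "CREATE TABLE orders (order_id INTEGER PRIMARY KEY, "
    "user_id INTEGER REFERENCES users (user_id))"
)
live_schema = introspect_schema(conn)
conn.close()
print(live_schema)
```

Pointing the same function at `sample_ecommerce.db` would give a schema description that always matches the tables the generated SQL runs against.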

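### Preparing the Training Dataset

In [None]:
# Create the natural-language-to-SQL dataset
# (these pairs serve as retrieval examples; no model is trained on them)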
training_data = [
    {
        "natural_language": "顯示所有用戶的信息",
        "sql": "SELECT * FROM users;",
        "explanation": "查詢用戶表中的所有記錄"
    },
    {
        "natural_language": "找出價格超過500元的產品",
        "sql": "SELECT * FROM products WHERE price > 500;",
        "explanation": "使用WHERE子句過濾價格大於500的產品"
    },
    {
        "natural_language": "統計每個用戶的訂單數量",
        "sql": "SELECT u.name, COUNT(o.order_id) as order_count FROM users u LEFT JOIN orders o ON u.user_id = o.user_id GROUP BY u.user_id, u.name;",
        "explanation": "使用JOIN和GROUP BY統計每個用戶的訂單數量"
    },
    {
        "natural_language": "查找最貴的產品",
        "sql": "SELECT * FROM products ORDER BY price DESC LIMIT 1;",
        "explanation": "使用ORDER BY和LIMIT找到價格最高的產品"
    },
    {
        "natural_language": "顯示用戶張三的所有訂單",
        "sql": "SELECT o.*, p.name as product_name FROM orders o JOIN users u ON o.user_id = u.user_id JOIN products p ON o.product_id = p.product_id WHERE u.name = '張三';",
        "explanation": "使用多表JOIN查詢特定用戶的訂單信息"
    },
    {
        "natural_language": "計算每個類別的產品平均價格",
        "sql": "SELECT category, AVG(price) as avg_price FROM products GROUP BY category;",
        "explanation": "使用GROUP BY和AVG函數計算各類別的平均價格"
    },
    {
        "natural_language": "找出沒有下過訂單的用戶",
        "sql": "SELECT u.* FROM users u LEFT JOIN orders o ON u.user_id = o.user_id WHERE o.user_id IS NULL;",
        "explanation": "使用LEFT JOIN和IS NULL找出沒有訂單的用戶"
    },
    {
        "natural_language": "顯示銷量最高的產品",
        "sql": "SELECT p.name, SUM(o.quantity) as total_sold FROM products p JOIN orders o ON p.product_id = o.product_id GROUP BY p.product_id, p.name ORDER BY total_sold DESC LIMIT 1;",
        "explanation": "統計產品銷量並找出銷量最高的產品"
    }
]

# Save the training data
with open('training_data.json', 'w', encoding='utf-8') as f:
    json.dump(training_data, f, ensure_ascii=False, indent=2)

print(f"The training dataset contains {len(training_data)} samples")
print("\nSample data:")
for i, item in enumerate(training_data[:3]):
    print(f"\nSample {i+1}:")
    print(f"Natural language: {item['natural_language']}")
    print(f"SQL: {item['sql']}")
    print(f"Explanation: {item['explanation']}")

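## 5. Vectorization and Indexing

### Using a Sentence Embedding Model

In [None]:
# Initialize the sentence embedding model
class TextEmbedder:
    def __init__(self, model_name='all-MiniLM-L6-v2'):
        """Initialize the text embedder."""
        # all-MiniLM-L6-v2 is trained mainly on English text; a multilingual model such as
        # 'paraphrase-multilingual-MiniLM-L12-v2' may suit the Chinese sample queries better.
        try:
            self.model = SentenceTransformer(model_name)
            print(f"Loaded model: {model_name}")
        except Exception as e:
            print(f"Failed to load model: {e}; using the fallback encoder")
            self.model = None

    def encode(self, texts: List[str]) -> np.ndarray:
        """Encode texts as vectors."""
        if self.model is not None:
            return self.model.encode(texts)
        # Simple fallback encoding
        return self._simple_encode(texts)

    def _simple_encode(self, texts: List[str], dim: int = 384) -> np.ndarray:
        """Fallback encoder: hashed bag of character bigrams with a fixed dimension."""
        # A per-call vocabulary would give index vectors and query vectors
        # different dimensions, so every bigram is hashed into `dim` buckets.
        # Character bigrams also work for text that has no spaces.
        import zlib
        vectors = np.zeros((len(texts), dim), dtype='float32')
        for row, text in enumerate(texts):
            text = text.lower()
            for i in range(max(len(text) - 1, 1)):
                bucket = zlib.crc32(text[i:i + 2].encode('utf-8')) % dim
                vectors[row, bucket] += 1.0
            norm = np.linalg.norm(vectors[row])
            if norm > 0:
                vectors[row] /= norm  # unit length keeps distances comparable
        return vectors

# Initialize the embedder
embedder = TextEmbedder()

# Test the embedding function
test_texts = [
    "查詢所有用戶",
    "顯示產品信息",
    "統計訂單數量"
]

embeddings = embedder.encode(test_texts)
print(f"Embedding shape: {embeddings.shape}")
print(f"Embedding sample: {embeddings[0][:5]}...")  # show the first 5 dimensions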

### Building the Vector Index

In [None]:
# Build the FAISS vector index
class VectorIndex:
    def __init__(self, dimension: int):
        """Initialize the vector index."""
        self.dimension = dimension
        self.index = faiss.IndexFlatL2(dimension)  # exact L2 index (reports squared distances)
        self.texts = []  # original texts
        self.metadata = []  # metadata for each text

    def add_vectors(self, vectors: np.ndarray, texts: List[str], metadata: List[Dict]):
        """Add vectors to the index."""
        self.index.add(vectors.astype('float32'))
        self.texts.extend(texts)
        self.metadata.extend(metadata)
        print(f"Added {len(vectors)} vectors to the index")

    def search(self, query_vector: np.ndarray, k: int = 5) -> List[Dict]:
        """Search for the most similar vectors."""
        query_vector = query_vector.reshape(1, -1).astype('float32')
        distances, indices = self.index.search(query_vector, k)

        results = []
        for i, (distance, idx) in enumerate(zip(distances[0], indices[0])):
            # FAISS pads with -1 when fewer than k vectors exist; skip those slots
            if 0 <= idx < len(self.texts):
                results.append({
                    'text': self.texts[idx],
                    'metadata': self.metadata[idx],
                    'distance': float(distance),
                    'rank': i + 1
                })
        return results

    def get_stats(self) -> Dict:
        """Return index statistics."""
        return {
            'total_vectors': self.index.ntotal,
            'dimension': self.dimension,
            'index_type': 'FlatL2'
        }

# Build a vector index over the training data
natural_language_queries = [item['natural_language'] for item in training_data]
query_embeddings = embedder.encode(natural_language_queries)

# Initialize the vector index
vector_index = VectorIndex(query_embeddings.shape[1])

# Add the vectors to the index
vector_index.add_vectors(
    query_embeddings,
    natural_language_queries,
    training_data
)

print("Vector index statistics:")
print(vector_index.get_stats())
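
It helps to know what the numbers returned by `IndexFlatL2` mean: it performs an exhaustive search and reports squared Euclidean distances. For unit-length vectors, the squared distance d and the cosine similarity are related by cosine = 1 - d / 2. The pure-Python sketch below reproduces that behavior on toy vectors; the helper names are illustrative, and whether the notebook's embeddings are unit-length depends on the model in use.

```python
import math
from typing import List, Tuple


def normalize(v: List[float]) -> List[float]:
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def squared_l2(a: List[float], b: List[float]) -> float:
    """Squared Euclidean distance, the quantity a flat L2 index reports."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def flat_l2_search(query: List[float], vectors: List[List[float]], k: int) -> List[Tuple[int, float]]:
    """Exhaustive nearest-neighbor search: (index, squared distance) pairs."""
    scored = sorted((squared_l2(query, v), i) for i, v in enumerate(vectors))
    return [(i, d) for d, i in scored[:k]]


vectors = [normalize(v) for v in ([1.0, 0.0], [1.0, 1.0], [0.0, 1.0])]
query = normalize([2.0, 0.1])
hits = flat_l2_search(query, vectors, k=2)

for idx, dist in hits:
    cosine = sum(x * y for x, y in zip(query, vectors[idx]))
    # For unit vectors, 1 - dist / 2 equals the cosine similarity.
    print(idx, round(dist, 4), round(1 - dist / 2, 4), round(cosine, 4))
```

This is why `1 - distance` is not a reliable similarity score: a squared L2 distance can exceed 1 even for moderately similar vectors.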

## 6. Retrieval System

### Implementing Semantic Retrieval

In [None]:
# Retrieval system
class SemanticRetriever:
    def __init__(self, embedder: TextEmbedder, vector_index: VectorIndex, schema: DatabaseSchema):
        """Initialize the semantic retriever."""
        self.embedder = embedder
        self.vector_index = vector_index
        self.schema = schema

    def retrieve(self, query: str, k: int = 3) -> Dict:
        """Retrieve related SQL examples."""
        # 1. Encode the query as a vector
        query_vector = self.embedder.encode([query])[0]

        # 2. Search the vector index
        similar_examples = self.vector_index.search(query_vector, k)

        # 3. Add the database schema information
        schema_info = self.schema.get_schema_info()

        # 4. Build the retrieval result
        retrieval_results = {
            'query': query,
            'schema_info': schema_info,
            'similar_examples': similar_examples,
            'retrieval_score': self._calculate_retrieval_score(similar_examples)
        }

        return retrieval_results

    def _calculate_retrieval_score(self, examples: List[Dict]) -> float:
        """Compute a retrieval quality score."""
        if not examples:
            return 0.0

        # Score from distance (smaller distance, higher score)
        distances = [ex['distance'] for ex in examples]
        avg_distance = sum(distances) / len(distances)
        score = max(0, 1 - avg_distance / 10)  # heuristic; the divisor 10 is arbitrary
        return score

    def explain_retrieval(self, results: Dict) -> str:
        """Explain the retrieval step."""
        explanation = f"\nRetrieval results:\n"
        explanation += f"Query: {results['query']}\n"
        explanation += f"Retrieval score: {results['retrieval_score']:.3f}\n"
        explanation += f"Found {len(results['similar_examples'])} similar examples:\n"

        for i, example in enumerate(results['similar_examples']):
            # Report the raw distance: 1 - distance is not a valid similarity
            # for squared L2 distances, which can exceed 1.
            explanation += f"\n{i+1}. Squared L2 distance: {example['distance']:.3f}\n"
            explanation += f"   Query: {example['text']}\n"
            explanation += f"   SQL: {example['metadata']['sql']}\n"

        return explanation

# Initialize the retriever
retriever = SemanticRetriever(embedder, vector_index, schema)

# Test retrieval
test_query = "我要查看所有產品的詳細信息"
retrieval_results = retriever.retrieve(test_query)

print("=== Retrieval test ===")
print(retriever.explain_retrieval(retrieval_results))

## 7. SQL Generation

### Generating SQL from Retrieval Results

In [None]:
# SQL generator
class SQLGenerator:
    def __init__(self, schema: DatabaseSchema):
        """Initialize the SQL generator."""
        self.schema = schema
        self.templates = self._load_sql_templates()

    def _load_sql_templates(self) -> Dict[str, str]:
        """Load the SQL templates."""
        return {
            'select_all': "SELECT * FROM {table};",
            'select_where': "SELECT * FROM {table} WHERE {condition};",
            'count': "SELECT COUNT(*) FROM {table};",
            'join': "SELECT {columns} FROM {table1} t1 JOIN {table2} t2 ON {join_condition};",
            'group_by': "SELECT {columns}, {aggregate} FROM {table} GROUP BY {group_columns};",
            'order_by': "SELECT * FROM {table} ORDER BY {column} {direction};"
        }
    
    def generate_sql(self, query: str, retrieval_results: Dict) -> Dict:
        """Generate SQL from the retrieval results."""
        # 1. Analyze the query intent
        intent = self._analyze_intent(query)

        # 2. Take the best match from the retrieval results
        best_match = retrieval_results['similar_examples'][0] if retrieval_results['similar_examples'] else None

        # 3. Generate the SQL
        if best_match and best_match['distance'] < 0.5:  # close match (small squared L2 distance)
            # Use the retrieved SQL as the starting point
            generated_sql = self._adapt_sql(best_match['metadata']['sql'], query)
            method = 'retrieval_based'
        else:
            # Fall back to template-based generation
            generated_sql = self._template_based_generation(query, intent)
            method = 'template_based'

        # 4. Validate the SQL
        is_valid, validation_error = self._validate_sql(generated_sql)
        
        return {
            'query': query,
            'generated_sql': generated_sql,
            'method': method,
            'intent': intent,
            'is_valid': is_valid,
            'validation_error': validation_error,
            'best_match': best_match
        }
    
    def _analyze_intent(self, query: str) -> Dict:
        """Analyze the query intent."""
        query_lower = query.lower()

        intent = {
            'action': 'select',  # default to a SELECT query
            'tables': [],
            'conditions': [],
            'aggregation': None,
            'sorting': None
        }

        # Detect table names
        for table in self.schema.tables.keys():
            if table in query_lower or self._table_synonyms(table, query_lower):
                intent['tables'].append(table)

        # Detect aggregation (simplified: every aggregate keyword maps to COUNT)
        if any(word in query_lower for word in ['統計', '計算', '總和', '平均']):
            intent['aggregation'] = 'count'

        # Detect sorting: superlatives such as 最高 or 最貴 imply descending order
        desc_words = ['最高', '最貴', '最大', '最多']
        asc_words = ['最低', '最便宜', '最小', '最少', '排序']
        if any(word in query_lower for word in desc_words):
            intent['sorting'] = 'desc'
        elif any(word in query_lower for word in asc_words):
            intent['sorting'] = 'asc'

        return intent
    
    def _table_synonyms(self, table: str, query: str) -> bool:
        """Check whether the query mentions a synonym of the table name."""
        synonyms = {
            'users': ['用戶', '使用者', '會員'],
            'products': ['產品', '商品', '物品'],
            'orders': ['訂單', '訂購', '購買']
        }
        
        if table in synonyms:
            return any(syn in query for syn in synonyms[table])
        return False
    
    def _adapt_sql(self, base_sql: str, query: str) -> str:
        """Adapt the base SQL to the new query."""
        # More sophisticated SQL adaptation could go here;
        # for now the base SQL is returned unchanged
        return base_sql

    def _template_based_generation(self, query: str, intent: Dict) -> str:
        """Generate SQL from templates."""
        if not intent['tables']:
            return "SELECT 'No table identified' as error;"

        table = intent['tables'][0]  # use the first table that was detected

        if intent['aggregation']:
            return f"SELECT COUNT(*) FROM {table};"
        elif intent['sorting']:
            # Assume sorting by the first numeric column
            numeric_columns = self._get_numeric_columns(table)
            if numeric_columns:
                column = numeric_columns[0]
                direction = intent['sorting'].upper()
                return f"SELECT * FROM {table} ORDER BY {column} {direction};"

        return f"SELECT * FROM {table};"
    
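    def _get_numeric_columns(self, table: str) -> List[str]:
        """Return the numeric, non-key columns of a table."""
        if table not in self.schema.tables:
            return []

        numeric_types = ['INT', 'DECIMAL', 'FLOAT', 'DOUBLE']
        numeric_columns = []

        for column, column_type in self.schema.tables[table].items():
            upper_type = column_type.upper()
            # Skip identifier columns: sorting by a key is rarely what the user means
            if 'PRIMARY KEY' in upper_type or column.endswith('_id'):
                continue
            if any(num_type in upper_type for num_type in numeric_types):
                numeric_columns.append(column)

        return numeric_columns

    def _validate_sql(self, sql: str) -> Tuple[bool, str]:
        """Run a lightweight sanity check on the SQL."""
        # sqlparse is a non-validating parser, so this only confirms that the
        # text contains a statement that sqlparse classifies as SELECT.
        try:
            statements = sqlparse.parse(sql)
            if not statements:
                return False, "Empty SQL statement"
            statement_type = statements[0].get_type()
            if statement_type != 'SELECT':
                return False, f"Expected a SELECT statement, got {statement_type}"
            return True, ""
        except Exception as e:
            return False, str(e)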

# Initialize the SQL generator
sql_generator = SQLGenerator(schema)

# Test SQL generation
test_queries = [
    "顯示所有產品信息",
    "找出最貴的產品",
    "統計用戶數量"
]

print("=== SQL generation test ===")
for query in test_queries:
    retrieval_results = retriever.retrieve(query)
    generation_results = sql_generator.generate_sql(query, retrieval_results)

    print(f"\nQuery: {query}")
    print(f"Generated SQL: {generation_results['generated_sql']}")
    print(f"Generation method: {generation_results['method']}")
    print(f"SQL valid: {generation_results['is_valid']}")
    if generation_results['validation_error']:
        print(f"Validation error: {generation_results['validation_error']}")
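
A parser-level check cannot tell whether the generated SQL refers to tables and columns that actually exist. One stricter option is to ask SQLite to plan the statement without running it, using `EXPLAIN QUERY PLAN`. The sketch below is a minimal version of that idea; the function name `check_select` is hypothetical, and the semicolon test is deliberately conservative (it also rejects semicolons inside string literals).

```python
import sqlite3
from typing import Tuple


def check_select(sql: str, conn: sqlite3.Connection) -> Tuple[bool, str]:
    """Check that `sql` is a single SELECT that SQLite can compile."""
    statement = sql.strip().rstrip(";").strip()
    if not statement.lower().startswith("select"):
        return False, "only SELECT statements are allowed"
    if ";" in statement:
        return False, "multiple statements are not allowed"
    try:
        # Compiling the plan resolves table and column names without returning data.
        conn.execute(f"EXPLAIN QUERY PLAN {statement}")
    except sqlite3.Error as exc:
        return False, str(exc)
    return True, ""


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE products (product_id INTEGER PRIMARY KEY, name TEXT, price REAL)")
ok = check_select("SELECT * FROM products ORDER BY price DESC;", conn)
missing = check_select("SELECT * FROM product;", conn)
not_select = check_select("DROP TABLE products;", conn)
conn.close()
print(ok, missing, not_select)
```

The check needs a connection to a database with the target schema, which is why it is shown here as a standalone function rather than as a change to `SQLGenerator`.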

## 8. Complete Pipeline

### Integrating All Components

In [None]:
# Complete TextSQL RAG pipeline
class TextSQLRAGPipeline:
    def __init__(self, 
                 embedder: TextEmbedder,
                 vector_index: VectorIndex,
                 schema: DatabaseSchema,
                 retriever: SemanticRetriever,
                 sql_generator: SQLGenerator):
        """Initialize the complete RAG pipeline."""
        self.embedder = embedder
        self.vector_index = vector_index
        self.schema = schema
        self.retriever = retriever
        self.sql_generator = sql_generator
        self.query_history = []

    def process_query(self, natural_language_query: str, execute_sql: bool = False) -> Dict:
        """Run the full flow for one natural-language query."""
        print(f"\nProcessing query: {natural_language_query}")

        # Step 1: retrieve related examples
        print("Step 1: retrieving related examples...")
        retrieval_results = self.retriever.retrieve(natural_language_query)

        # Step 2: generate SQL
        print("Step 2: generating SQL...")
        generation_results = self.sql_generator.generate_sql(natural_language_query, retrieval_results)

        # Step 3: execute the SQL (optional)
        execution_results = None
        if execute_sql and generation_results['is_valid']:
            print("Step 3: executing SQL...")
            execution_results = self._execute_sql(generation_results['generated_sql'])

        # Step 4: assemble the complete result
        complete_results = {
            'natural_language_query': natural_language_query,
            'retrieval_results': retrieval_results,
            'generation_results': generation_results,
            'execution_results': execution_results,
            'pipeline_success': generation_results['is_valid'],
            'timestamp': pd.Timestamp.now().isoformat()
        }

        # Step 5: record the query history
        self.query_history.append(complete_results)

        return complete_results
    
    def _execute_sql(self, sql: str) -> Dict:
        """Execute the SQL query against the sample database."""
        conn = None
        try:
            conn = sqlite3.connect('sample_ecommerce.db')

            # Run the query
            result_df = pd.read_sql_query(sql, conn)

            return {
                'success': True,
                'data': result_df.to_dict('records'),
                'row_count': len(result_df),
                'columns': list(result_df.columns),
                'error': None
            }

        except Exception as e:
            return {
                'success': False,
                'data': None,
                'row_count': 0,
                'columns': [],
                'error': str(e)
            }

        finally:
            # Close the connection even when the query fails
            if conn is not None:
                conn.close()
    
    def explain_results(self, results: Dict) -> str:
        """Explain the processing results."""
        explanation = f"\n{'='*50}\n"
        explanation += f"Query: {results['natural_language_query']}\n"
        explanation += f"{'='*50}\n"

        # Retrieval stage
        explanation += "\n🔍 Retrieval stage:\n"
        retrieval = results['retrieval_results']
        explanation += f"Found {len(retrieval['similar_examples'])} similar examples\n"
        explanation += f"Retrieval score: {retrieval['retrieval_score']:.3f}\n"

        if retrieval['similar_examples']:
            best_match = retrieval['similar_examples'][0]
            # Report the raw distance; 1 - distance is not a valid similarity here
            explanation += f"Best match: {best_match['text']} (squared L2 distance: {best_match['distance']:.3f})\n"

        # Generation stage
        explanation += "\n⚙️ Generation stage:\n"
        generation = results['generation_results']
        explanation += f"Generation method: {generation['method']}\n"
        explanation += f"Generated SQL: {generation['generated_sql']}\n"
        explanation += f"SQL valid: {generation['is_valid']}\n"

        # Execution stage
        if results['execution_results']:
            explanation += "\n🚀 Execution stage:\n"
            execution = results['execution_results']
            if execution['success']:
                explanation += f"Execution succeeded, returned {execution['row_count']} rows\n"
                if execution['data']:
                    explanation += "First rows:\n"
                    for i, row in enumerate(execution['data'][:3]):
                        explanation += f"  {i+1}: {row}\n"
            else:
                explanation += f"Execution failed: {execution['error']}\n"

        explanation += f"\n✅ Pipeline success: {results['pipeline_success']}\n"

        return explanation
    
    def get_pipeline_stats(self) -> Dict:
        """Return pipeline statistics."""
        if not self.query_history:
            return {'total_queries': 0}
        
        successful_queries = sum(1 for q in self.query_history if q['pipeline_success'])
        
        return {
            'total_queries': len(self.query_history),
            'successful_queries': successful_queries,
            'success_rate': successful_queries / len(self.query_history),
            'avg_retrieval_score': np.mean([q['retrieval_results']['retrieval_score'] 
                                          for q in self.query_history])
        }

# Initialize the complete pipeline
pipeline = TextSQLRAGPipeline(
    embedder=embedder,
    vector_index=vector_index,
    schema=schema,
    retriever=retriever,
    sql_generator=sql_generator
)

print("✅ TextSQL RAG pipeline initialized!")

### Pipeline Demo

In [None]:
# Pipeline demo
import time

demo_queries = [
    "顯示所有用戶的基本信息",
    "查找價格最高的產品",
    "統計每個類別有多少產品",
    "顯示張三購買的所有商品",
    "找出還沒有下過訂單的用戶"
]

print("🚀 Starting the pipeline demo\n")

for i, query in enumerate(demo_queries, 1):
    print(f"\n{'='*60}")
    print(f"Demo {i}/{len(demo_queries)}")
    print(f"{'='*60}")

    # Process the query
    results = pipeline.process_query(query, execute_sql=True)

    # Show the results
    print(pipeline.explain_results(results))

    # Pause briefly so the output is easier to follow
    time.sleep(1)

# Show overall statistics
print("\n" + "="*60)
print("Pipeline statistics")
print("="*60)
stats = pipeline.get_pipeline_stats()
for key, value in stats.items():
    print(f"{key}: {value}")

## 9. Evaluation and Optimization

### Evaluation Metrics

In [None]:
# Evaluation system
class RAGEvaluator:
    def __init__(self, pipeline: TextSQLRAGPipeline):
        """Initialize the evaluator."""
        self.pipeline = pipeline

    def evaluate_retrieval(self, test_cases: List[Dict]) -> Dict:
        """Evaluate retrieval performance."""
        retrieval_scores = []
        precision_scores = []
        
        for test_case in test_cases:
            query = test_case['natural_language']
            expected_sql = test_case['sql']
            
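            # Retrieve related examples
            retrieval_results = self.pipeline.retriever.retrieve(query)

            # Record the retrieval score
            retrieval_scores.append(retrieval_results['retrieval_score'])

            # Precision here is a hit test: do the retrieved examples contain the correct answer?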
            precision = self._calculate_precision(retrieval_results, expected_sql)
            precision_scores.append(precision)
        
        return {
            'avg_retrieval_score': np.mean(retrieval_scores),
            'avg_precision': np.mean(precision_scores),
            'retrieval_scores': retrieval_scores,
            'precision_scores': precision_scores
        }
    
    def evaluate_generation(self, test_cases: List[Dict]) -> Dict:
        """Evaluate SQL generation performance."""
        exact_match_scores = []
        syntax_valid_scores = []
        semantic_similarity_scores = []
        
        for test_case in test_cases:
            query = test_case['natural_language']
            expected_sql = test_case['sql']
            
            # Generate SQL
            results = self.pipeline.process_query(query, execute_sql=False)
            generated_sql = results['generation_results']['generated_sql']

            # Exact match
            exact_match = self._normalize_sql(generated_sql) == self._normalize_sql(expected_sql)
            exact_match_scores.append(exact_match)

            # Syntactic validity
            syntax_valid = results['generation_results']['is_valid']
            syntax_valid_scores.append(syntax_valid)

            # Semantic similarity (simplified implementation)
            semantic_sim = self._calculate_semantic_similarity(generated_sql, expected_sql)
            semantic_similarity_scores.append(semantic_sim)
        
        return {
            'exact_match_rate': np.mean(exact_match_scores),
            'syntax_valid_rate': np.mean(syntax_valid_scores),
            'avg_semantic_similarity': np.mean(semantic_similarity_scores),
            'exact_match_scores': exact_match_scores,
            'syntax_valid_scores': syntax_valid_scores,
            'semantic_similarity_scores': semantic_similarity_scores
        }
    
    def evaluate_end_to_end(self, test_cases: List[Dict]) -> Dict:
        """Evaluate the pipeline end to end."""
        execution_success_scores = []
        result_accuracy_scores = []
        
        for test_case in test_cases:
            query = test_case['natural_language']
            expected_sql = test_case['sql']
            
            # Run the full pipeline
            results = self.pipeline.process_query(query, execute_sql=True)

            # Execution success
            execution_success = (
                results['execution_results'] is not None and 
                results['execution_results']['success']
            )
            execution_success_scores.append(execution_success)

            # Result accuracy (compared against the results of the expected SQL)
            if execution_success:
                result_accuracy = self._compare_execution_results(
                    results['generation_results']['generated_sql'],
                    expected_sql
                )
                result_accuracy_scores.append(result_accuracy)
            else:
                result_accuracy_scores.append(0.0)
        
        return {
            'execution_success_rate': np.mean(execution_success_scores),
            'result_accuracy_rate': np.mean(result_accuracy_scores),
            'execution_success_scores': execution_success_scores,
            'result_accuracy_scores': result_accuracy_scores
        }
    
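    def _calculate_precision(self, retrieval_results: Dict, expected_sql: str) -> float:
        """Return 1.0 if any retrieved example is close to the expected SQL, else 0.0."""
        similar_examples = retrieval_results['similar_examples']
        if not similar_examples:
            return 0.0

        # Check whether any retrieved SQL is similar to the expected SQL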
        for example in similar_examples:
            if self._sql_similarity(example['metadata']['sql'], expected_sql) > 0.8:
                return 1.0
        
        return 0.0
    
    def _normalize_sql(self, sql: str) -> str:
        """Normalize a SQL string."""
        # Collapse extra whitespace and convert to lowercase
        return ' '.join(sql.lower().split())
    
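    def _calculate_semantic_similarity(self, sql1: str, sql2: str) -> float:
        """Approximate SQL similarity as the Jaccard overlap of tokens."""
        # Simplified implementation: token overlap, not true semantic equivalence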
        words1 = set(self._normalize_sql(sql1).split())
        words2 = set(self._normalize_sql(sql2).split())
        
        if not words1 and not words2:
            return 1.0
        if not words1 or not words2:
            return 0.0
        
        intersection = words1.intersection(words2)
        union = words1.union(words2)
        
        return len(intersection) / len(union)
    
    def _sql_similarity(self, sql1: str, sql2: str) -> float:
        """Compute SQL similarity."""
        return self._calculate_semantic_similarity(sql1, sql2)

    def _compare_execution_results(self, generated_sql: str, expected_sql: str) -> float:
        """Compare the execution results of two SQL statements."""
        conn = None
        try:
            conn = sqlite3.connect('sample_ecommerce.db')

            # Run both statements
            result1 = pd.read_sql_query(generated_sql, conn)
            result2 = pd.read_sql_query(expected_sql, conn)

            # Compare the results
            if result1.equals(result2):
                return 1.0
            elif len(result1) == len(result2) and len(result1.columns) == len(result2.columns):
                return 0.5  # partial match: same shape, different values or labels
            else:
                return 0.0

        except Exception:
            return 0.0

        finally:
            # Close the connection even when a query fails
            if conn is not None:
                conn.close()
    
    def generate_evaluation_report(self, test_cases: List[Dict]) -> str:
        """Generate the evaluation report."""
        retrieval_eval = self.evaluate_retrieval(test_cases)
        generation_eval = self.evaluate_generation(test_cases)
        e2e_eval = self.evaluate_end_to_end(test_cases)
        
        report = "\n" + "="*50 + "\n"
        report += "TextSQL RAG Pipeline Evaluation Report\n"
        report += "="*50 + "\n"
        
        report += "\n📊 Retrieval performance:\n"
        report += f"  Average retrieval score: {retrieval_eval['avg_retrieval_score']:.3f}\n"
        report += f"  Average precision: {retrieval_eval['avg_precision']:.3f}\n"
        
        report += "\n🔧 Generation performance:\n"
        report += f"  Exact match rate: {generation_eval['exact_match_rate']:.3f}\n"
        report += f"  Syntax validity rate: {generation_eval['syntax_valid_rate']:.3f}\n"
        report += f"  Average semantic similarity: {generation_eval['avg_semantic_similarity']:.3f}\n"
        
        report += "\n🚀 End-to-end performance:\n"
        report += f"  Execution success rate: {e2e_eval['execution_success_rate']:.3f}\n"
        report += f"  Result accuracy rate: {e2e_eval['result_accuracy_rate']:.3f}\n"
        
        # Overall score: unweighted mean of one headline metric per stage
        overall_score = np.mean([
            retrieval_eval['avg_retrieval_score'],
            generation_eval['syntax_valid_rate'],
            e2e_eval['execution_success_rate']
        ])
        
        report += f"\n⭐ Overall score: {overall_score:.3f}\n"
        
        return report

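# Create the evaluator
evaluator = RAGEvaluator(pipeline)

# Run the evaluation
# Note: training_data doubles as the test set here, so if the same examples
# were used to build the retrieval index, the scores will be optimistic
print("🔍 Evaluating pipeline performance...")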
evaluation_report = evaluator.generate_evaluation_report(training_data)
print(evaluation_report)
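
The keyword-overlap score used by `_calculate_semantic_similarity` is a Jaccard index over whitespace-separated tokens, so it is only a rough proxy for semantic equivalence. The sketch below re-implements the same logic as standalone functions (the names `normalize_sql` and `jaccard_similarity` are illustrative, not part of the notebook) and shows two cases where the score and the actual meaning disagree.

```python
def normalize_sql(sql: str) -> str:
    """Lowercase and collapse whitespace, mirroring _normalize_sql."""
    return ' '.join(sql.lower().split())


def jaccard_similarity(sql1: str, sql2: str) -> float:
    """Token-set overlap |A & B| / |A | B|, mirroring _calculate_semantic_similarity."""
    words1 = set(normalize_sql(sql1).split())
    words2 = set(normalize_sql(sql2).split())
    if not words1 and not words2:
        return 1.0
    if not words1 or not words2:
        return 0.0
    return len(words1 & words2) / len(words1 | words2)


# Same tokens, different meaning: the comparison is reversed,
# yet both queries produce the same token set.
same_tokens = jaccard_similarity(
    "SELECT name FROM users WHERE age > 30",
    "SELECT name FROM users WHERE 30 > age",
)

# Same meaning, different spacing: "id=1" is one token, "id = 1" is three.
same_meaning = jaccard_similarity(
    "SELECT * FROM users WHERE id=1",
    "SELECT * FROM users WHERE id = 1",
)

print(f"same tokens, different meaning:  {same_tokens:.3f}")
print(f"same meaning, different spacing: {same_meaning:.3f}")
```

The first pair scores 1.0 even though the filters are opposite, and the second pair scores 5/9 even though the queries are equivalent. This is one reason the report also includes execution-based metrics.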

## 10. Practical Application Examples

### Interactive Query Interface

In [None]:
# Interactive query interface
class InteractiveQueryInterface:
    def __init__(self, pipeline: TextSQLRAGPipeline):
        """Initialize the interactive interface."""
        self.pipeline = pipeline
        self.session_history = []
    
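    def start_session(self):
        """Start an interactive session."""
        print("\n🎯 Welcome to the TextSQL RAG query system!")
        print("Enter a natural-language query; the system will generate the corresponding SQL and execute it")
        print("Type 'help' for help, or 'quit' to exit\n")
        
        while True:
            try:
                user_input = input("🔍 Enter your query: ").strip()
                command = user_input.lower()
                
                if command == 'quit':
                    print("👋 Thanks for using the system. Goodbye!")
                    break
                elif command == 'help':
                    self._show_help()
                    continue
                elif command == 'history':
                    self._show_history()
                    continue
                elif command == 'schema':
                    self._show_schema()
                    continue
                elif not user_input:
                    continue
                
                # Process the query
                self._process_interactive_query(user_input)
                
            except (KeyboardInterrupt, EOFError):
                # input() raises EOFError when stdin is closed (e.g. non-interactive runs);
                # without this branch the generic handler below would loop forever
                print("\n👋 Session interrupted. Goodbye!")
                break
            except Exception as e:
                print(f"❌ An error occurred: {e}")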
    
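    def _process_interactive_query(self, query: str):
        """Process a single interactive query."""
        print("\n⏳ Processing...")
        
        # Run the query through the pipeline
        results = self.pipeline.process_query(query, execute_sql=True)
        
        # Record it in the session history
        self.session_history.append(results)
        
        # Display the results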
        self._display_results(results)
    
    def _display_results(self, results: Dict):
        """Display the query results."""
        print("\n" + "="*50)
        print(f"📝 Query: {results['natural_language_query']}")
        print("="*50)
        
        # Show the generated SQL
        generation = results['generation_results']
        print("\n💻 Generated SQL:")
        print(f"```sql\n{generation['generated_sql']}\n```")
        
        print(f"\n🔧 Generation method: {generation['method']}")
        print(f"✅ SQL valid: {generation['is_valid']}")
        
        # Show the execution results
        if results['execution_results']:
            execution = results['execution_results']
            if execution['success']:
                print(f"\n🎯 Execution succeeded: {execution['row_count']} row(s) returned")
                
                if execution['data']:
                    # Format the rows as a table
                    df = pd.DataFrame(execution['data'])
                    print("\n📊 Query results:")
                    print(df.to_string(index=False))
                else:
                    print("\n📊 Query results: no rows returned")
            else:
                print(f"\n❌ Execution failed: {execution['error']}")
        
        # Show retrieval information
        retrieval = results['retrieval_results']
        if retrieval['similar_examples']:
            best_match = retrieval['similar_examples'][0]
            print(f"\n🔍 Best matching example: {best_match['text']}")
            # Assumes the retriever reports a distance for which 1 - distance
            # reads as a similarity (e.g. cosine distance)
            print(f"📈 Similarity: {1-best_match['distance']:.3f}")
    
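    def _show_help(self):
        """Show the help text."""
        help_text = """
📚 Help:

Example queries:
  • "顯示所有用戶" (show all users) - queries the users table
  • "找出最貴的產品" (find the most expensive product) - sorts by price
  • "統計每個類別的產品數量" (count products per category) - aggregate query
  • "顯示張三的所有訂單" (show all of 張三's orders) - join query

Special commands:
  • help - show this help
  • history - show the query history
  • schema - show the database structure
  • quit - exit the system

💡 Tip: describe what you need in natural language!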
"""
        print(help_text)
    
    def _show_history(self):
        """Show the query history."""
        if not self.session_history:
            print("\n📝 No query history yet")
            return
        
        print(f"\n📝 Query history ({len(self.session_history)} total, showing up to the last 5):")
        print("-" * 60)
        
        for i, result in enumerate(self.session_history[-5:], 1):  # Only the 5 most recent entries
            query = result['natural_language_query']
            sql = result['generation_results']['generated_sql']
            success = result['pipeline_success']
            status = "✅" if success else "❌"
            
            print(f"{i}. {status} {query}")
            # Truncate long SQL to 50 characters
            print(f"   SQL: {sql[:50]}{'...' if len(sql) > 50 else ''}")
            print()
    
    def _show_schema(self):
        """Show the database structure."""
        print("\n🗃️ Database structure:")
        print(self.pipeline.schema.get_schema_info())

# Create the interactive interface
interface = InteractiveQueryInterface(pipeline)

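# Run a few scripted queries as a demo (instead of the real interactive mode)
print("\n🎯 TextSQL RAG system demo")
print("The following example queries are executed automatically:\n")

# The queries stay in Chinese to match the language of the training examples
demo_queries = [
    "查看所有用戶的信息",  # "Show all users' information"
    "找出價格最高的三個產品",  # "Find the three most expensive products"
    "統計每個用戶的訂單總數"  # "Count the total number of orders per user"
]

for query in demo_queries:
    print(f"🔍 Query: {query}")
    interface._process_interactive_query(query)
    print("\n" + "-"*60 + "\n")

print("\n✨ Demo complete! To start the real interactive mode, uncomment the code below:")
print("# interface.start_session()")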
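
As the number of special commands grows, the `if`/`elif` chain in `start_session` can be replaced by a lookup table. The following is a minimal sketch of that alternative, not the notebook's implementation: `CommandRouter` and its `log` list are hypothetical stand-ins for the interface and its `_show_*` methods, so the example runs without a pipeline.

```python
from typing import Callable, Dict, List


class CommandRouter:
    """Hypothetical helper that maps special commands to handlers."""

    def __init__(self) -> None:
        self.log: List[str] = []  # records what happened, in place of printing
        self.commands: Dict[str, Callable[[], None]] = {
            'help': lambda: self.log.append('help shown'),
            'history': lambda: self.log.append('history shown'),
            'schema': lambda: self.log.append('schema shown'),
        }

    def handle(self, user_input: str) -> bool:
        """Return False when the session should end, True otherwise."""
        text = user_input.strip()
        command = text.lower()
        if command == 'quit':
            return False
        if not text:
            return True  # ignore empty input
        handler = self.commands.get(command)
        if handler is not None:
            handler()
        else:
            # Anything that is not a command is treated as a query
            self.log.append(f'query: {text}')
        return True


router = CommandRouter()
for line in ['HELP', '', 'show all users', 'schema']:
    router.handle(line)
still_running = router.handle('quit')
print(router.log, still_running)
```

Because `handle` takes a string and returns a flag, the loop logic can be exercised with scripted input rather than through `input()`.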

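## Summary and Optimization Suggestions

### Summary

This notebook implemented a complete TextSQL RAG pipeline, consisting of:

1. **Data preprocessing**: creating the sample database and training data
2. **Vectorization**: converting text to vectors with a sentence-embedding model
3. **Retrieval system**: retrieving relevant examples by semantic similarity
4. **SQL generation**: generating SQL from retrieved examples and templates
5. **Full pipeline**: integrating all components
6. **Evaluation system**: evaluating performance along multiple dimensions
7. **Practical application**: an interactive query interface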

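### Optimization Suggestions

1. **Model optimization**:
   - Use a stronger embedding model (e.g., OpenAI embeddings)
   - Add fine-tuning to adapt to a specific domain

2. **Retrieval optimization**:
   - Implement hybrid retrieval (semantic + keyword)
   - Add a reranking step

3. **Generation optimization**:
   - Integrate a large language model (e.g., GPT-4)
   - Implement more complex SQL templates

4. **System optimization**:
   - Add caching
   - Support distributed deployment
   - Strengthen error handling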
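
To make the hybrid retrieval suggestion concrete, here is a minimal sketch that blends a semantic score with a keyword-overlap score through a weighted sum. Everything in it is an assumption for illustration: the function names, the `alpha` weight, and the semantic scores, which are made-up numbers standing in for embedding-model output. Whitespace tokenization is also a simplification and would need a proper tokenizer for Chinese queries.

```python
from typing import Dict, List, Tuple


def keyword_score(query: str, doc: str) -> float:
    """Jaccard overlap between whitespace-separated token sets."""
    q, d = set(query.lower().split()), set(doc.lower().split())
    if not q or not d:
        return 0.0
    return len(q & d) / len(q | d)


def hybrid_rank(query: str, docs: List[str],
                semantic_scores: Dict[str, float],
                alpha: float = 0.5) -> List[Tuple[str, float]]:
    """Rank docs by alpha * semantic score + (1 - alpha) * keyword score."""
    scored = [
        (doc, alpha * semantic_scores.get(doc, 0.0)
              + (1 - alpha) * keyword_score(query, doc))
        for doc in docs
    ]
    return sorted(scored, key=lambda pair: pair[1], reverse=True)


docs = [
    "count orders per user",
    "total orders for each user",
    "list all products",
]
# Hypothetical semantic similarities, standing in for embedding-model output
semantic_scores = {docs[0]: 0.80, docs[1]: 0.85, docs[2]: 0.10}

query = "count orders per user"
semantic_only = hybrid_rank(query, docs, semantic_scores, alpha=1.0)
hybrid = hybrid_rank(query, docs, semantic_scores, alpha=0.5)
print("semantic only:", semantic_only[0][0])
print("hybrid:       ", hybrid[0][0])
```

With these toy numbers, the semantic-only ranking puts the paraphrase first, while the hybrid ranking promotes the example that matches the query word for word. The same sorted list could feed a reranking step.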

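### Further Learning

1. Study the Transformer architecture in depth
2. Explore more advanced RAG techniques
3. Learn database optimization techniques
4. Investigate multimodal RAG systems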

In [None]:
# Save the pipeline state and results
import json
import pickle

def save_pipeline_state():
    """Save the pipeline state."""
    state = {
        'query_history': pipeline.query_history,
        'pipeline_stats': pipeline.get_pipeline_stats(),
        'training_data': training_data,
        'schema_info': schema.get_schema_info()
    }
    
    with open('pipeline_state.pkl', 'wb') as f:
        pickle.dump(state, f)
    
    # Also save a JSON summary that is easy to inspect
    json_state = {
        'pipeline_stats': pipeline.get_pipeline_stats(),
        'training_data': training_data,
        'schema_info': schema.get_schema_info()
    }
    
    with open('pipeline_summary.json', 'w', encoding='utf-8') as f:
        json.dump(json_state, f, ensure_ascii=False, indent=2)
    
    print("✅ Pipeline state saved")

# Save the state
save_pipeline_state()

stats = pipeline.get_pipeline_stats()

print("\n🎉 TextSQL RAG Pipeline study notes complete!")
print("\n📝 Results:")
print(f"  • Processed {len(pipeline.query_history)} queries")
print(f"  • Success rate: {stats.get('success_rate', 0):.1%}")
print(f"  • Average retrieval score: {stats.get('avg_retrieval_score', 0):.3f}")
print("\n🚀 You can now:")
print("  1. Run this notebook in a Kaggle environment")
print("  2. Modify the database schema and training data")
print("  3. Improve the retrieval and generation algorithms")
print("  4. Integrate a more powerful language model")
print("\n💡 Suggested next step: test this system on a real dataset!")
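
The saved files can be read back in a later session. The sketch below assumes the two file names used by `save_pipeline_state`; the helper `load_pipeline_state` is hypothetical, and the demo writes stand-in data to a temporary directory so it runs on its own. Only unpickle files you created yourself, since `pickle.load` can execute arbitrary code from an untrusted file.

```python
import json
import pickle
import tempfile
from pathlib import Path


def load_pipeline_state(pkl_path, json_path):
    """Load the pickled state and the JSON summary (hypothetical helper)."""
    with open(pkl_path, 'rb') as f:
        state = pickle.load(f)  # trusted files only
    with open(json_path, 'r', encoding='utf-8') as f:
        summary = json.load(f)
    return state, summary


# Round-trip demo with stand-in data instead of a real pipeline
with tempfile.TemporaryDirectory() as tmp:
    pkl_path = Path(tmp) / 'pipeline_state.pkl'
    json_path = Path(tmp) / 'pipeline_summary.json'

    demo_state = {
        'query_history': [{'query': '查看所有用戶的信息'}],
        'pipeline_stats': {'success_rate': 1.0},
    }
    with open(pkl_path, 'wb') as f:
        pickle.dump(demo_state, f)
    with open(json_path, 'w', encoding='utf-8') as f:
        json.dump({'pipeline_stats': demo_state['pipeline_stats']}, f,
                  ensure_ascii=False, indent=2)

    state, summary = load_pipeline_state(pkl_path, json_path)

print(state['query_history'][0]['query'])
print(summary['pipeline_stats'])
```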