<a href="https://colab.research.google.com/github/wangyiyang/RAG-Cookbook-Code/blob/main/ch03/query_optimizer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%pip install sentence-transformers torch transformers faiss-cpu

Collecting faiss-cpu
  Downloading faiss_cpu-1.11.0.post1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.0 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-m

In [2]:
"""
查询理解与优化模块
实现查询分类、扩展、重写等功能
"""

import re
from typing import List, Dict, Any
from dataclasses import dataclass


@dataclass
class QueryAnalysis:
    """查询分析结果"""
    original_query: str
    query_type: str
    expanded_terms: List[str]
    optimized_query: str
    confidence: float


class QueryClassifier:
    """查询分类器"""

    def __init__(self):
        self.patterns = {
            'short_query': r'^.{1,10}$',  # 短查询（1-10字符）
            'question': r'[什么|如何|为什么|怎么|哪个|哪些]',  # 问句
            'exact_match': r'[""].*[""]',  # 精确匹配（带引号）
            'ambiguous': r'[它|这个|那个|前面|上述]'  # 模糊指代
        }

    def classify(self, query: str) -> str:
        """对查询进行分类"""
        query = query.strip()

        # 优先检查歧义和精确匹配
        if re.search(self.patterns['ambiguous'], query):
            return 'ambiguous'
        if re.search(self.patterns['exact_match'], query):
            return 'exact_match'
        if re.search(self.patterns['question'], query):
            return 'question'

        # 根据长度判断（排除问句后）
        if len(query) <= 10 and not re.search(self.patterns['question'], query):
            return 'short_query'
        elif len(query.split()) >= 15:
            return 'complex_query'
        else:
            return 'normal_query'


class QueryExpander:
    """查询扩展器"""

    def __init__(self):
        # 同义词词典（实际应用中可从外部文件加载）
        self.synonyms = {
            'RAG': ['检索增强生成', '检索增强', '信息检索'],
            'AI': ['人工智能', '机器学习', '深度学习'],
            '性能': ['效率', '速度', '响应时间'],
            '优化': ['改进', '提升', '增强'],
            '检索增强生成': ['RAG', '向量检索', '语义搜索'],
            '检索算法': ['BM25', 'TF-IDF', '语义相似度', '混合检索']
        }

        # 概念扩展词典
        self.concept_expansions = {
            'RAG技术': ['向量检索', '文档分块', '语义搜索', 'LLM生成'],
            '检索算法': ['BM25', 'TF-IDF', '语义相似度', '混合检索'],
            '重排序': ['Cross-Encoder', 'ColBERT', 'Learning-to-Rank']
        }

    def expand(self, query: str) -> List[str]:
        """扩展查询词汇"""
        expanded_terms = []

        # 检查整个查询中的关键词
        for key, synonyms in self.synonyms.items():
            if key in query:
                expanded_terms.extend(synonyms)

        for key, concepts in self.concept_expansions.items():
            if key in query:
                expanded_terms.extend(concepts)

        # 也检查单个词汇
        words = query.split()
        for word in words:
            if word in self.synonyms:
                expanded_terms.extend(self.synonyms[word])
            if word in self.concept_expansions:
                expanded_terms.extend(self.concept_expansions[word])

        return list(set(expanded_terms))  # 去重，不包含原查询


class QueryRewriter:
    """查询重写器"""

    def __init__(self):
        self.rewrite_rules = {
            # 疑问词规范化
            r'怎么样': '如何',
            r'咋样': '如何',
            r'啥': '什么',
            # 术语规范化
            r'ai': 'AI',
            r'rag': 'RAG',
            r'llm': 'LLM'
        }

    def rewrite(self, query: str) -> str:
        """重写查询以提升检索效果"""
        rewritten = query

        # 应用重写规则
        for pattern, replacement in self.rewrite_rules.items():
            rewritten = re.sub(pattern, replacement, rewritten, flags=re.IGNORECASE)

        # 移除停用词（保留问句结构）
        if not re.search(r'[什么|如何|为什么|怎么|哪个|哪些]', rewritten):
            stop_words = ['的', '了', '在', '是', '有', '和', '与']
            words = rewritten.split()
            filtered_words = [w for w in words if w not in stop_words]
            rewritten = ' '.join(filtered_words)

        return rewritten


class QueryOptimizer:
    """查询优化器主类"""

    def __init__(self):
        self.query_classifier = QueryClassifier()
        self.query_expander = QueryExpander()
        self.query_rewriter = QueryRewriter()

    def optimize(self, query: str) -> QueryAnalysis:
        """优化查询的主入口"""
        # 1. 查询分类
        query_type = self.query_classifier.classify(query)

        # 2. 查询扩展
        expanded_terms = self.query_expander.expand(query)

        # 3. 查询重写
        optimized_query = self.query_rewriter.rewrite(query)

        # 4. 针对性优化
        if query_type == "short_query":
            # 短查询添加更多扩展词
            optimized_query = self._enhance_short_query(optimized_query, expanded_terms)
        elif query_type == "ambiguous":
            # 歧义查询消解
            optimized_query = self._disambiguate_query(optimized_query)

        # 5. 计算优化置信度
        confidence = self._calculate_confidence(query, optimized_query, query_type)

        return QueryAnalysis(
            original_query=query,
            query_type=query_type,
            expanded_terms=expanded_terms,
            optimized_query=optimized_query,
            confidence=confidence
        )

    def _enhance_short_query(self, query: str, expanded_terms: List[str]) -> str:
        """增强短查询"""
        # 添加最相关的扩展词
        relevant_terms = expanded_terms[:3]  # 取前3个
        if relevant_terms:
            enhanced = f"{query} {' '.join(relevant_terms)}"
        else:
            enhanced = query
        return enhanced

    def _disambiguate_query(self, query: str) -> str:
        """消除查询歧义"""
        # 简单的歧义消解策略
        if '它' in query:
            return query.replace('它', 'RAG技术')
        if '这个' in query:
            return query.replace('这个', '该方法')
        return query

    def _calculate_confidence(self, original: str, optimized: str, query_type: str) -> float:
        """计算优化置信度"""
        if original == optimized:
            return 0.5  # 无优化

        confidence_map = {
            'short_query': 0.8,
            'ambiguous': 0.9,
            'question': 0.7,
            'normal_query': 0.7,
            'complex_query': 0.6
        }

        return confidence_map.get(query_type, 0.7)


# 使用示例
if __name__ == "__main__":
    optimizer = QueryOptimizer()

    test_queries = [
        "RAG",
        "什么是检索增强生成？",
        "它的性能怎么样？",
        "如何优化检索算法的效率和准确性"
    ]

    for query in test_queries:
        analysis = optimizer.optimize(query)
        print(f"原查询: {analysis.original_query}")
        print(f"查询类型: {analysis.query_type}")
        print(f"扩展词汇: {analysis.expanded_terms}")
        print(f"优化后: {analysis.optimized_query}")
        print(f"置信度: {analysis.confidence:.2f}")
        print("-" * 50)

原查询: RAG
查询类型: short_query
扩展词汇: ['信息检索', '检索增强', '检索增强生成']
优化后: RAG 信息检索 检索增强 检索增强生成
置信度: 0.80
--------------------------------------------------
原查询: 什么是检索增强生成？
查询类型: question
扩展词汇: ['语义搜索', 'RAG', '向量检索']
优化后: 什么是检索增强生成？
置信度: 0.50
--------------------------------------------------
原查询: 它的性能怎么样？
查询类型: ambiguous
扩展词汇: ['速度', '响应时间', '效率']
优化后: RAG技术的性能如何？
置信度: 0.90
--------------------------------------------------
原查询: 如何优化检索算法的效率和准确性
查询类型: question
扩展词汇: ['语义相似度', '混合检索', '改进', '增强', 'TF-IDF', 'BM25', '提升']
优化后: 如何优化检索算法的效率和准确性
置信度: 0.50
--------------------------------------------------
