<a href="https://colab.research.google.com/github/wangyiyang/RAG-Cookbook-Code/blob/main/ch03/hybrid_retriever.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%pip install sentence-transformers torch transformers faiss-cpu

Collecting faiss-cpu
  Downloading faiss_cpu-1.11.0.post1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.0 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-m

In [4]:
"""
混合检索策略实现
结合关键词检索和语义检索的优势
"""

import numpy as np
from typing import List, Dict, Tuple, Any
from dataclasses import dataclass
from collections import defaultdict
import math


@dataclass
class SearchResult:
    """检索结果"""
    doc_id: str
    content: str
    score: float
    metadata: Dict[str, Any]
    source: str  # 'semantic', 'keyword', 'hybrid'


class BM25Retriever:
    """BM25关键词检索器"""

    def __init__(self, k1: float = 1.5, b: float = 0.75):
        self.k1 = k1
        self.b = b
        self.documents = []
        self.doc_freqs = []
        self.idf_cache = {}
        self.avgdl = 0

    def fit(self, documents: List[str]):
        """构建BM25索引"""
        self.documents = documents
        self.doc_freqs = []

        # 计算词频和文档长度
        total_length = 0
        word_doc_count = defaultdict(int)

        for doc in documents:
            words = doc.lower().split()
            total_length += len(words)

            word_freq = defaultdict(int)
            unique_words = set()

            for word in words:
                word_freq[word] += 1
                unique_words.add(word)

            # 计算每个唯一词的文档频率
            for word in unique_words:
                word_doc_count[word] += 1

            self.doc_freqs.append(word_freq)

        self.avgdl = total_length / len(documents)

        # 计算IDF
        N = len(documents)
        for word, df in word_doc_count.items():
            self.idf_cache[word] = math.log((N - df + 0.5) / (df + 0.5))

    def search(self, query: str, top_k: int = 10) -> List[SearchResult]:
        """BM25检索"""
        query_words = query.lower().split()
        scores = []

        for doc_idx, doc_freq in enumerate(self.doc_freqs):
            score = 0.0
            doc_len = sum(doc_freq.values())

            for word in query_words:
                if word in doc_freq:
                    tf = doc_freq[word]
                    idf = self.idf_cache.get(word, 0)

                    # BM25公式
                    numerator = tf * (self.k1 + 1)
                    denominator = tf + self.k1 * (1 - self.b + self.b * doc_len / self.avgdl)
                    score += idf * (numerator / denominator)

            scores.append((doc_idx, score))

        # 排序并返回top-k
        scores.sort(key=lambda x: x[1], reverse=True)

        results = []
        for doc_idx, score in scores[:top_k]:
            results.append(SearchResult(
                doc_id=str(doc_idx),
                content=self.documents[doc_idx],
                score=score,
                metadata={'doc_index': doc_idx},
                source='keyword'
            ))

        return results


class DenseRetriever:
    """密集检索器（模拟）"""

    def __init__(self, embedding_dim: int = 768):
        self.embedding_dim = embedding_dim
        self.document_embeddings = []
        self.documents = []

    def fit(self, documents: List[str]):
        """构建向量索引"""
        self.documents = documents
        # 模拟文档向量化（实际应用中使用预训练模型）
        self.document_embeddings = [
            self._mock_embed(doc) for doc in documents
        ]

    def _mock_embed(self, text: str) -> np.ndarray:
        """模拟文本嵌入（实际使用sentence-transformers）"""
        # 基于文本哈希生成模拟向量
        hash_val = hash(text) % (2**31)
        np.random.seed(hash_val)
        return np.random.randn(self.embedding_dim)

    def search(self, query: str, top_k: int = 10) -> List[SearchResult]:
        """向量相似度检索"""
        query_embedding = self._mock_embed(query)

        # 计算余弦相似度
        similarities = []
        for doc_idx, doc_embedding in enumerate(self.document_embeddings):
            similarity = np.dot(query_embedding, doc_embedding) / (
                np.linalg.norm(query_embedding) * np.linalg.norm(doc_embedding)
            )
            similarities.append((doc_idx, similarity))

        # 排序并返回top-k
        similarities.sort(key=lambda x: x[1], reverse=True)

        results = []
        for doc_idx, similarity in similarities[:top_k]:
            results.append(SearchResult(
                doc_id=str(doc_idx),
                content=self.documents[doc_idx],
                score=similarity,
                metadata={'doc_index': doc_idx},
                source='semantic'
            ))

        return results


class AdaptiveWeightManager:
    """自适应权重管理器"""

    def __init__(self):
        self.query_patterns = {
            'exact_match': 0.3,      # 精确匹配查询，偏向关键词检索
            'conceptual': 0.8,       # 概念性查询，偏向语义检索
            'mixed': 0.5,            # 混合查询，平衡权重
            'technical': 0.6         # 技术查询，稍偏向语义检索
        }

    def get_optimal_weight(self, query: str) -> float:
        """根据查询特征动态调整权重"""
        query_type = self._classify_query_pattern(query)
        return self.query_patterns.get(query_type, 0.7)

    def _classify_query_pattern(self, query: str) -> str:
        """分类查询模式"""
        query = query.lower()

        # 精确匹配模式
        if '"' in query or '精确' in query or '具体' in query:
            return 'exact_match'

        # 概念性查询
        concept_keywords = ['原理', '概念', '理论', '什么是', '如何理解']
        if any(keyword in query for keyword in concept_keywords):
            return 'conceptual'

        # 技术查询
        tech_keywords = ['实现', '算法', '代码', '技术', '方法']
        if any(keyword in query for keyword in tech_keywords):
            return 'technical'

        return 'mixed'


class HybridRetriever:
    """混合检索器主类"""

    def __init__(self, documents: List[str]):
        self.documents = documents
        self.bm25_retriever = BM25Retriever()
        self.dense_retriever = DenseRetriever()
        self.weight_manager = AdaptiveWeightManager()

        # 构建索引
        self.bm25_retriever.fit(documents)
        self.dense_retriever.fit(documents)

    def search(self, query: str, top_k: int = 10, alpha: float = None) -> List[SearchResult]:
        """混合检索主方法"""
        # 自适应权重
        if alpha is None:
            alpha = self.weight_manager.get_optimal_weight(query)

        # 并行执行双路检索
        semantic_results = self.dense_retriever.search(query, top_k=top_k * 2)
        keyword_results = self.bm25_retriever.search(query, top_k=top_k * 2)

        # 结果融合
        combined_results = self.reciprocal_rank_fusion(
            semantic_results,
            keyword_results,
            alpha
        )

        return combined_results[:top_k]

    def reciprocal_rank_fusion(
        self,
        semantic_results: List[SearchResult],
        keyword_results: List[SearchResult],
        alpha: float,
        k: int = 60
    ) -> List[SearchResult]:
        """倒数排名融合算法（RRF）"""
        combined_scores = {}
        doc_objects = {}

        # 语义检索分数
        for rank, result in enumerate(semantic_results):
            doc_id = result.doc_id
            rrf_score = alpha / (k + rank + 1)
            combined_scores[doc_id] = rrf_score
            doc_objects[doc_id] = result

        # 关键词检索分数
        for rank, result in enumerate(keyword_results):
            doc_id = result.doc_id
            rrf_score = (1 - alpha) / (k + rank + 1)

            if doc_id in combined_scores:
                combined_scores[doc_id] += rrf_score
            else:
                combined_scores[doc_id] = rrf_score
                doc_objects[doc_id] = result

        # 按分数排序
        sorted_items = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)

        # 构建最终结果
        final_results = []
        for doc_id, score in sorted_items:
            result = doc_objects[doc_id]
            final_results.append(SearchResult(
                doc_id=result.doc_id,
                content=result.content,
                score=score,
                metadata=result.metadata,
                source='hybrid'
            ))

        return final_results

    def evaluate_fusion_strategies(
        self,
        query: str,
        strategies: List[str] = None
    ) -> Dict[str, List[SearchResult]]:
        """评估不同融合策略的效果"""
        if strategies is None:
            strategies = ['rrf', 'weighted_sum', 'comb_mnz']

        semantic_results = self.dense_retriever.search(query, top_k=20)
        keyword_results = self.bm25_retriever.search(query, top_k=20)

        results = {}

        for strategy in strategies:
            if strategy == 'rrf':
                results[strategy] = self.reciprocal_rank_fusion(
                    semantic_results, keyword_results, 0.7
                )[:10]
            elif strategy == 'weighted_sum':
                results[strategy] = self._weighted_sum_fusion(
                    semantic_results, keyword_results, 0.7
                )[:10]
            elif strategy == 'comb_mnz':
                results[strategy] = self._comb_mnz_fusion(
                    semantic_results, keyword_results
                )[:10]

        return results

    def _weighted_sum_fusion(
        self,
        semantic_results: List[SearchResult],
        keyword_results: List[SearchResult],
        alpha: float
    ) -> List[SearchResult]:
        """加权和融合策略"""
        combined_scores = {}
        doc_objects = {}

        # 归一化分数
        max_semantic = max((r.score for r in semantic_results), default=1) if semantic_results else 1
        max_keyword = max((r.score for r in keyword_results), default=1) if keyword_results else 1

        # 防止除零错误
        if max_semantic == 0:
            max_semantic = 1
        if max_keyword == 0:
            max_keyword = 1

        # 语义检索分数
        for result in semantic_results:
            doc_id = result.doc_id
            normalized_score = result.score / max_semantic
            combined_scores[doc_id] = alpha * normalized_score
            doc_objects[doc_id] = result

        # 关键词检索分数
        for result in keyword_results:
            doc_id = result.doc_id
            normalized_score = result.score / max_keyword

            if doc_id in combined_scores:
                combined_scores[doc_id] += (1 - alpha) * normalized_score
            else:
                combined_scores[doc_id] = (1 - alpha) * normalized_score
                doc_objects[doc_id] = result

        # 排序
        sorted_items = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)

        final_results = []
        for doc_id, score in sorted_items:
            result = doc_objects[doc_id]
            final_results.append(SearchResult(
                doc_id=result.doc_id,
                content=result.content,
                score=score,
                metadata=result.metadata,
                source='hybrid'
            ))

        return final_results

    def _comb_mnz_fusion(
        self,
        semantic_results: List[SearchResult],
        keyword_results: List[SearchResult]
    ) -> List[SearchResult]:
        """CombMNZ融合策略"""
        combined_scores = {}
        doc_objects = {}
        match_counts = defaultdict(int)

        # 收集所有结果
        for result in semantic_results:
            doc_id = result.doc_id
            combined_scores[doc_id] = result.score
            doc_objects[doc_id] = result
            match_counts[doc_id] += 1

        for result in keyword_results:
            doc_id = result.doc_id
            if doc_id in combined_scores:
                combined_scores[doc_id] += result.score
                match_counts[doc_id] += 1
            else:
                combined_scores[doc_id] = result.score
                doc_objects[doc_id] = result
                match_counts[doc_id] += 1

        # CombMNZ: score * match_count
        for doc_id in combined_scores:
            combined_scores[doc_id] *= match_counts[doc_id]

        # 排序
        sorted_items = sorted(combined_scores.items(), key=lambda x: x[1], reverse=True)

        final_results = []
        for doc_id, score in sorted_items:
            result = doc_objects[doc_id]
            final_results.append(SearchResult(
                doc_id=result.doc_id,
                content=result.content,
                score=score,
                metadata=result.metadata,
                source='hybrid'
            ))

        return final_results


# 使用示例
if __name__ == "__main__":
    # 模拟文档数据
    documents = [
        "RAG是检索增强生成技术，结合了信息检索和语言生成",
        "BM25是经典的关键词检索算法，基于TF-IDF改进",
        "向量检索通过语义嵌入实现文档相似度计算",
        "混合检索策略可以结合关键词和语义检索的优势",
        "重排序算法进一步提升检索结果的相关性"
    ]

    # 初始化混合检索器
    retriever = HybridRetriever(documents)

    # 执行检索
    query = "RAG检索算法"
    results = retriever.search(query, top_k=3)

    print(f"查询: {query}")
    print(f"检索结果:")
    for i, result in enumerate(results, 1):
        print(f"{i}. [分数: {result.score:.3f}] {result.content[:50]}...")

    # 评估不同融合策略
    print(f"\n融合策略对比:")
    strategy_results = retriever.evaluate_fusion_strategies(query)
    for strategy, results in strategy_results.items():
        print(f"{strategy}: {len(results)} 个结果")

查询: RAG检索算法
检索结果:
1. [分数: 0.016] 重排序算法进一步提升检索结果的相关性...
2. [分数: 0.016] RAG是检索增强生成技术，结合了信息检索和语言生成...
3. [分数: 0.016] 混合检索策略可以结合关键词和语义检索的优势...

融合策略对比:
rrf: 5 个结果
weighted_sum: 5 个结果
comb_mnz: 5 个结果
