<a href="https://colab.research.google.com/github/wangyiyang/RAG-Cookbook-Code/blob/main/ch03/performance_optimizer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%pip install sentence-transformers torch transformers faiss-cpu

Collecting faiss-cpu
  Downloading faiss_cpu-1.11.0.post1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.0 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-m

In [2]:
"""
性能优化与缓存策略
提升检索系统的响应速度和吞吐量
"""

import time
import hashlib
import asyncio
from typing import Dict, List, Any, Optional, Callable
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from collections import OrderedDict
import threading


@dataclass
class CacheItem:
    """One stored cache entry together with its bookkeeping metadata."""

    # The cached payload itself.
    value: Any
    # Wall-clock time (seconds) recorded when the entry was written.
    timestamp: float
    # How many times the entry has been read back.
    access_count: int = 0
    # Lifetime in seconds counted from ``timestamp``; one hour by default.
    ttl: float = 3600


@dataclass
class PerformanceMetrics:
    """Running counters for request volume, cache usage and latency."""

    total_requests: int = 0
    cache_hits: int = 0
    cache_misses: int = 0
    avg_response_time: float = 0.0
    total_response_time: float = 0.0

    @property
    def cache_hit_rate(self) -> float:
        """Share of all recorded requests served from cache (0.0 when none)."""
        return self.cache_hits / self.total_requests if self.total_requests else 0.0

    def update_response_time(self, response_time: float):
        """Record one finished request and refresh the mean response time.

        Args:
            response_time: Duration of the request in seconds.
        """
        request_count = self.total_requests + 1
        elapsed_sum = self.total_response_time + response_time
        self.total_response_time = elapsed_sum
        self.total_requests = request_count
        self.avg_response_time = elapsed_sum / request_count


class LRUCache:
    """Least-recently-used cache with a per-entry time-to-live.

    Entries are kept in an ``OrderedDict`` ordered from least to most
    recently used. ``get``, ``put``, ``clear``, ``size`` and
    ``cleanup_expired`` all run while holding ``self.lock`` (an ``RLock``).

    A ``max_size`` of zero or less leaves no room for new entries, so
    ``put`` stores nothing for keys that are not already present.
    """

    def __init__(self, max_size: int = 1000, default_ttl: float = 3600):
        """Create an empty cache.

        Args:
            max_size: Maximum number of entries kept at once.
            default_ttl: Lifetime in seconds for entries stored without an
                explicit ``ttl``.
        """
        self.max_size = max_size
        self.default_ttl = default_ttl
        self.cache: OrderedDict[str, CacheItem] = OrderedDict()
        self.lock = threading.RLock()

    def _generate_key(self, *args, **kwargs) -> str:
        """Return an MD5 hex digest derived from the given call arguments.

        Keyword arguments are sorted by name first, so the order in which
        they are passed does not change the key. The digest is built from
        the ``str()`` form of the arguments.
        """
        key_str = str(args) + str(sorted(kwargs.items()))
        return hashlib.md5(key_str.encode()).hexdigest()

    def get(self, key: str) -> Optional[Any]:
        """Return the value stored under ``key``, or ``None``.

        ``None`` is returned both for a missing key and for an entry whose
        TTL has elapsed; an expired entry is removed on the way out. A hit
        bumps the entry's access counter and marks it most recently used.
        """
        with self.lock:
            item = self.cache.get(key)
            if item is None:
                return None

            # Expired entries are evicted lazily, at read time.
            if time.time() - item.timestamp > item.ttl:
                del self.cache[key]
                return None

            item.access_count += 1
            self.cache.move_to_end(key)
            return item.value

    def put(self, key: str, value: Any, ttl: Optional[float] = None) -> None:
        """Store ``value`` under ``key``.

        Args:
            key: Cache key.
            value: Object to cache.
            ttl: Lifetime in seconds; ``None`` means ``self.default_ttl``.

        An existing key is overwritten in place (value, timestamp and TTL
        are refreshed, the access counter is kept). For a new key the least
        recently used entries are evicted until there is room.
        """
        with self.lock:
            if ttl is None:
                ttl = self.default_ttl

            existing = self.cache.get(key)
            if existing is not None:
                existing.value = value
                existing.timestamp = time.time()
                existing.ttl = ttl
                self.cache.move_to_end(key)
                return

            # Evict from the least-recently-used end. The emptiness guard
            # keeps a non-positive max_size from popping an empty dict.
            while self.cache and len(self.cache) >= self.max_size:
                self.cache.popitem(last=False)

            # A cache without capacity cannot hold the new entry.
            if self.max_size <= 0:
                return

            self.cache[key] = CacheItem(
                value=value,
                timestamp=time.time(),
                ttl=ttl
            )

    def clear(self) -> None:
        """Remove every entry."""
        with self.lock:
            self.cache.clear()

    def size(self) -> int:
        """Return the number of stored entries, expired ones included."""
        with self.lock:
            return len(self.cache)

    def cleanup_expired(self) -> int:
        """Delete every entry whose TTL has elapsed.

        Returns:
            The number of entries removed.
        """
        with self.lock:
            current_time = time.time()
            # Collect first: the dict must not change size while iterating.
            expired_keys = [
                key
                for key, item in self.cache.items()
                if current_time - item.timestamp > item.ttl
            ]
            for key in expired_keys:
                del self.cache[key]
            return len(expired_keys)


class QueryCache:
    """Groups the three LRU caches used by the retrieval layer.

    ``query_cache`` holds query results, ``embedding_cache`` holds text
    embeddings and ``result_cache`` is a third, smaller cache; each has its
    own capacity and default TTL.
    """

    def __init__(self, max_size: int = 1000):
        """Build the caches, sizing each one relative to ``max_size``."""
        # Smallest cache, 10-minute TTL.
        self.result_cache = LRUCache(max_size // 2, default_ttl=600)
        # Twice the base capacity, 2-hour TTL.
        self.embedding_cache = LRUCache(max_size * 2, default_ttl=7200)
        # Base capacity, 30-minute TTL.
        self.query_cache = LRUCache(max_size, default_ttl=1800)

    @staticmethod
    def _text_key(text: str) -> str:
        """Return the MD5 hex digest of ``text``, used as the embedding key."""
        return hashlib.md5(text.encode()).hexdigest()

    def get_query_result(self, query: str, **kwargs) -> Optional[Any]:
        """Look up the cached result for ``query`` and its extra parameters."""
        return self.query_cache.get(self._generate_query_key(query, **kwargs))

    def cache_query_result(self, query: str, result: Any, **kwargs) -> None:
        """Store ``result`` for ``query`` and its extra parameters."""
        self.query_cache.put(self._generate_query_key(query, **kwargs), result)

    def get_embedding(self, text: str) -> Optional[Any]:
        """Look up the cached embedding of ``text``."""
        return self.embedding_cache.get(self._text_key(text))

    def cache_embedding(self, text: str, embedding: Any) -> None:
        """Store the embedding of ``text`` with a two-hour TTL."""
        self.embedding_cache.put(self._text_key(text), embedding, ttl=7200)

    def _generate_query_key(self, query: str, **kwargs) -> str:
        """Return an MD5 hex digest of the query plus its keyword parameters.

        The parameters are sorted by name so the key does not depend on the
        order in which they were passed.
        """
        ordered_pairs = sorted({'query': query, **kwargs}.items())
        digest = hashlib.md5(str(ordered_pairs).encode())
        return digest.hexdigest()

    def get_cache_stats(self) -> Dict[str, Dict]:
        """Report the current size and capacity of each cache, by name."""
        named_caches = (
            ('query_cache', self.query_cache),
            ('embedding_cache', self.embedding_cache),
            ('result_cache', self.result_cache),
        )
        return {
            label: {'size': store.size(), 'max_size': store.max_size}
            for label, store in named_caches
        }


class AsyncRetrievalPipeline:
    """Runs blocking retrieval callables concurrently on a thread pool.

    Work is submitted to ``self.executor`` from coroutines, and a semaphore
    caps the number of calls in flight at ``max_workers``. The semaphore is
    re-created whenever the pipeline is used from a different event loop, so
    one instance can serve several ``asyncio.run(...)`` calls in a row.
    Using one instance from several event loops at the same time is not
    supported.
    """

    def __init__(self, max_workers: int = 10):
        """Create the thread pool and the concurrency limiter.

        Args:
            max_workers: Thread-pool size and maximum concurrent calls.
        """
        self._max_workers = max_workers
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.semaphore = asyncio.Semaphore(max_workers)
        # Loop that ``self.semaphore`` was last created for; None until the
        # pipeline is first used inside a running loop.
        self._semaphore_loop = None

    def _current_semaphore(self) -> asyncio.Semaphore:
        """Return a semaphore that belongs to the running event loop.

        An asyncio semaphore that has been waited on is tied to one loop;
        reusing it from another loop raises ``RuntimeError``. A fresh one is
        therefore created the first time each loop uses the pipeline.
        """
        loop = asyncio.get_running_loop()
        if self._semaphore_loop is not loop:
            self.semaphore = asyncio.Semaphore(self._max_workers)
            self._semaphore_loop = loop
        return self.semaphore

    async def parallel_retrieval(
        self,
        query: str,
        retrievers: List[Callable],
        top_k: int = 10
    ) -> List[Any]:
        """Run every retriever on ``query`` concurrently.

        Args:
            query: Query passed to each retriever.
            retrievers: Callables invoked as ``retriever(query, top_k)``.
            top_k: Second positional argument passed to each retriever.

        Returns:
            The results of the retrievers that succeeded, in input order.
            Retrievers that raised are left out; a warning is logged when
            that happens.
        """
        import logging

        loop = asyncio.get_running_loop()
        semaphore = self._current_semaphore()

        async def run_retriever(retriever_func):
            async with semaphore:
                return await loop.run_in_executor(
                    self.executor,
                    retriever_func,
                    query,
                    top_k
                )

        outcomes = await asyncio.gather(
            *(run_retriever(retriever) for retriever in retrievers),
            return_exceptions=True
        )

        # gather() may hand back any BaseException (e.g. CancelledError),
        # not only Exception subclasses.
        failures = [o for o in outcomes if isinstance(o, BaseException)]
        if failures:
            logging.getLogger(__name__).warning(
                "%d of %d retrievers failed; first error: %r",
                len(failures), len(outcomes), failures[0]
            )

        return [o for o in outcomes if not isinstance(o, BaseException)]

    async def batch_process(
        self,
        queries: List[str],
        process_func: Callable,
        batch_size: int = 50
    ) -> List[Any]:
        """Apply ``process_func`` to every query, one batch at a time.

        Args:
            queries: Queries to process.
            process_func: Callable invoked as ``process_func(query)``.
            batch_size: Number of queries gathered together per batch.

        Returns:
            One entry per query, in input order. When a call raised, the
            exception object takes the place of its result.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        results = []
        for start in range(0, len(queries), batch_size):
            batch = queries[start:start + batch_size]
            batch_results = await asyncio.gather(
                *(self._process_single_query(query, process_func) for query in batch),
                return_exceptions=True
            )
            results.extend(batch_results)

        return results

    async def _process_single_query(self, query: str, process_func: Callable):
        """Run ``process_func(query)`` on the thread pool under the limiter."""
        loop = asyncio.get_running_loop()
        async with self._current_semaphore():
            return await loop.run_in_executor(self.executor, process_func, query)

    def shutdown(self, wait: bool = True) -> None:
        """Shut down the thread pool; the pipeline cannot be used afterwards.

        Args:
            wait: Forwarded to ``ThreadPoolExecutor.shutdown``.
        """
        self.executor.shutdown(wait=wait)


class RetrievalOptimizer:
    """Adds caching, metrics and async fan-out around retrieval callables."""

    def __init__(self, cache_size: int = 1000):
        """Create the caches, the metrics object and the async pipeline.

        Args:
            cache_size: Base capacity handed to ``QueryCache``.
        """
        self.cache = QueryCache(cache_size)
        self.metrics = PerformanceMetrics()
        self.async_pipeline = AsyncRetrievalPipeline()

    def optimized_retrieval(
        self,
        query: str,
        retrieval_func: Callable,
        use_cache: bool = True,
        **kwargs
    ) -> Any:
        """Run ``retrieval_func``, serving and filling the query cache.

        Args:
            query: Query string.
            retrieval_func: Callable invoked as
                ``retrieval_func(query, **kwargs)``.
            use_cache: When false the cache is neither read nor written.
            **kwargs: Forwarded to ``retrieval_func``; also part of the
                cache key.

        Returns:
            The cached or freshly computed result. ``None`` results are
            never cached.

        Raises:
            Whatever ``retrieval_func`` raises; the elapsed time is still
            recorded in the metrics.
        """
        start_time = time.time()

        try:
            if use_cache:
                cached_result = self.cache.get_query_result(query, **kwargs)
                if cached_result is not None:
                    self.metrics.cache_hits += 1
                    return cached_result
                self.metrics.cache_misses += 1

            result = retrieval_func(query, **kwargs)

            if use_cache and result is not None:
                self.cache.cache_query_result(query, result, **kwargs)

            return result
        finally:
            # Runs once on every exit path: hit, miss or exception.
            self.metrics.update_response_time(time.time() - start_time)

    def get_embedding_with_cache(
        self,
        text: str,
        embedding_func: Callable
    ) -> Any:
        """Return the embedding of ``text``, computing it only on a miss.

        Args:
            text: Text to embed.
            embedding_func: Callable invoked as ``embedding_func(text)``.
        """
        cached_embedding = self.cache.get_embedding(text)
        if cached_embedding is not None:
            return cached_embedding

        embedding = embedding_func(text)
        self.cache.cache_embedding(text, embedding)
        return embedding

    async def optimized_async_retrieval(
        self,
        query: str,
        retrievers: List[Callable],
        top_k: int = 10
    ) -> List[Any]:
        """Run all ``retrievers`` concurrently through the async pipeline."""
        return await self.async_pipeline.parallel_retrieval(query, retrievers, top_k)

    def get_performance_report(self) -> Dict[str, Any]:
        """Return metrics, cache statistics and tuning recommendations."""
        cache_stats = self.cache.get_cache_stats()

        return {
            'metrics': {
                'total_requests': self.metrics.total_requests,
                'cache_hit_rate': f"{self.metrics.cache_hit_rate:.1%}",
                'avg_response_time': f"{self.metrics.avg_response_time:.3f}s",
                'cache_hits': self.metrics.cache_hits,
                'cache_misses': self.metrics.cache_misses
            },
            'cache_stats': cache_stats,
            'recommendations': self._generate_recommendations()
        }

    def _generate_recommendations(self) -> List[str]:
        """Build the list of tuning hints from current metrics and cache use."""
        recommendations = []

        # A hit rate only means something once the cache has been consulted.
        cache_lookups = self.metrics.cache_hits + self.metrics.cache_misses
        if cache_lookups > 0 and self.metrics.cache_hit_rate < 0.6:
            recommendations.append("缓存命中率偏低，建议增加缓存容量或调整TTL")

        if self.metrics.avg_response_time > 0.5:
            recommendations.append("平均响应时间偏高，建议优化检索算法或增加并发")

        cache_stats = self.cache.get_cache_stats()
        for cache_name, stats in cache_stats.items():
            capacity = stats['max_size']
            # A cache with no capacity has no usage ratio; skipping it also
            # avoids dividing by zero.
            if capacity <= 0:
                continue
            if stats['size'] / capacity > 0.9:
                recommendations.append(f"{cache_name}使用率过高，建议增加容量")

        return recommendations

    def cleanup_caches(self) -> Dict[str, int]:
        """Purge expired entries and return the removal count per cache."""
        return {
            'query_cache': self.cache.query_cache.cleanup_expired(),
            'embedding_cache': self.cache.embedding_cache.cleanup_expired(),
            'result_cache': self.cache.result_cache.cleanup_expired(),
        }


# Decorator: automatic result caching
def cached_retrieval(cache_ttl: int = 1800):
    """Return a decorator that caches a retrieval function's results.

    Each decorated function gets its own ``LRUCache``, exposed as the
    ``cache`` attribute of the wrapper. The key is derived from the query
    and all other arguments. ``None`` results are stored but read back as a
    miss, so the function is called again for them.

    Args:
        cache_ttl: Lifetime in seconds of each cached result.
    """
    import functools

    def decorator(func):
        cache = LRUCache(default_ttl=cache_ttl)

        # wraps() keeps the decorated function's name and docstring.
        @functools.wraps(func)
        def wrapper(query: str, *args, **kwargs):
            key = cache._generate_key(query, *args, **kwargs)

            result = cache.get(key)
            if result is not None:
                return result

            result = func(query, *args, **kwargs)
            cache.put(key, result)
            return result

        wrapper.cache = cache
        return wrapper

    return decorator


# Usage example
if __name__ == "__main__":

    # Stand-in retrieval function
    def mock_retrieval(query: str, top_k: int = 10):
        """Simulate a retrieval call that returns ``top_k`` documents."""
        time.sleep(0.1)  # simulate retrieval latency
        return [f"文档{i}: 关于{query}的内容" for i in range(top_k)]

    def mock_embedding(text: str):
        """Simulate an embedding call that returns a fixed vector."""
        time.sleep(0.05)  # simulate compute latency
        return [0.1, 0.2, 0.3] * 100  # fake 300-element vector

    # Create the optimizer
    optimizer = RetrievalOptimizer(cache_size=100)

    # Exercise the query cache
    queries = ["RAG技术", "检索算法", "重排序", "RAG技术"]  # repeated query to trigger a cache hit

    print("测试检索优化...")
    for query in queries:
        start = time.time()
        result = optimizer.optimized_retrieval(query, mock_retrieval, top_k=5)
        duration = time.time() - start
        print(f"查询 '{query}' 耗时: {duration:.3f}s, 结果数: {len(result)}")

    # Exercise the embedding cache
    print("\n测试嵌入缓存...")
    texts = ["RAG原理", "检索技术", "RAG原理"]  # repeated text to trigger a cache hit

    for text in texts:
        start = time.time()
        embedding = optimizer.get_embedding_with_cache(text, mock_embedding)
        duration = time.time() - start
        print(f"嵌入 '{text}' 耗时: {duration:.3f}s, 维度: {len(embedding)}")

    # Performance report
    print("\n性能报告:")
    report = optimizer.get_performance_report()
    for key, value in report['metrics'].items():
        print(f"  {key}: {value}")

    if report['recommendations']:
        print("\n优化建议:")
        for rec in report['recommendations']:
            print(f"  - {rec}")

    # Purge expired cache entries
    print("\n清理过期缓存:")
    cleanup_results = optimizer.cleanup_caches()
    for cache_name, count in cleanup_results.items():
        print(f"  {cache_name}: 清理了 {count} 个过期项")

测试检索优化...
查询 'RAG技术' 耗时: 0.100s, 结果数: 5
查询 '检索算法' 耗时: 0.100s, 结果数: 5
查询 '重排序' 耗时: 0.100s, 结果数: 5
查询 'RAG技术' 耗时: 0.000s, 结果数: 5

测试嵌入缓存...
嵌入 'RAG原理' 耗时: 0.051s, 维度: 300
嵌入 '检索技术' 耗时: 0.050s, 维度: 300
嵌入 'RAG原理' 耗时: 0.000s, 维度: 300

性能报告:
  total_requests: 4
  cache_hit_rate: 25.0%
  avg_response_time: 0.075s
  cache_hits: 1
  cache_misses: 3

优化建议:
  - 缓存命中率偏低，建议增加缓存容量或调整TTL

清理过期缓存:
  query_cache: 清理了 0 个过期项
  embedding_cache: 清理了 0 个过期项
  result_cache: 清理了 0 个过期项
