<a href="https://colab.research.google.com/github/wangyiyang/RAG-Cookbook-Code/blob/main/ch03/evaluation_metrics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
%pip install sentence-transformers torch transformers faiss-cpu

Collecting faiss-cpu
  Downloading faiss_cpu-1.11.0.post1-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (5.0 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-m

In [3]:
"""
评估指标与效果监控
提供全面的检索系统评估方法
"""

import numpy as np
from typing import List, Dict, Set, Tuple, Any, Optional
from dataclasses import dataclass
import math
from collections import defaultdict
import time


@dataclass
class EvaluationResult:
    """Retrieval-quality metrics produced for an evaluated query.

    ``query_id`` is carried alongside the metrics for bookkeeping and is not
    included in the mapping returned by :meth:`to_dict`.
    """
    precision: float
    recall: float
    f1_score: float
    ndcg: float
    map_score: float  # average precision of the ranking
    mrr: float        # reciprocal rank of the first relevant result
    hit_rate: float
    query_id: str = ""

    def to_dict(self) -> Dict[str, float]:
        """Return the numeric metrics keyed by field name (``query_id`` omitted)."""
        metric_names = (
            'precision',
            'recall',
            'f1_score',
            'ndcg',
            'map_score',
            'mrr',
            'hit_rate',
        )
        return {name: getattr(self, name) for name in metric_names}


@dataclass
class PerformanceStats:
    """Latency, throughput and error figures for a batch of recorded requests."""
    avg_response_time: float
    p95_response_time: float
    p99_response_time: float
    throughput: float  # queries per second
    error_rate: float

    def to_dict(self) -> Dict[str, float]:
        """Return every statistic keyed by its field name."""
        return dict(
            avg_response_time=self.avg_response_time,
            p95_response_time=self.p95_response_time,
            p99_response_time=self.p99_response_time,
            throughput=self.throughput,
            error_rate=self.error_rate,
        )


class RetrievalEvaluator:
    """Score retrieval results and keep simple latency/error counters.

    Relevance is binary: a retrieved id counts as relevant when it appears in
    the ground-truth list. The ranking metrics (NDCG, average precision) credit
    each relevant id at most once, so a ranking that repeats an id cannot
    score above 1.0 and stays consistent with the set-based precision/recall.
    """

    def __init__(self):
        self.response_times: List[float] = []
        self.error_count = 0
        self.total_requests = 0

    def evaluate_single_query(
        self,
        retrieved_docs: List[str],
        relevant_docs: List[str],
        query_id: str = ""
    ) -> EvaluationResult:
        """Compute all quality metrics for one query.

        Args:
            retrieved_docs: Document ids in ranked order, best first.
            relevant_docs: Ground-truth relevant document ids.
            query_id: Optional identifier stored on the result.

        Returns:
            An ``EvaluationResult``; every metric is 0.0 when it is undefined
            (for example, nothing retrieved or no relevant documents).
        """
        retrieved_set = set(retrieved_docs)
        relevant_set = set(relevant_docs)

        # Set-based metrics ignore ranking order and duplicate ids.
        true_positives = len(retrieved_set & relevant_set)
        precision = true_positives / len(retrieved_set) if retrieved_set else 0.0
        recall = true_positives / len(relevant_set) if relevant_set else 0.0
        denominator = precision + recall
        f1_score = 2 * precision * recall / denominator if denominator > 0 else 0.0

        return EvaluationResult(
            precision=precision,
            recall=recall,
            f1_score=f1_score,
            ndcg=self._calculate_ndcg(retrieved_docs, relevant_docs),
            map_score=self._calculate_average_precision(retrieved_docs, relevant_docs),
            mrr=self._calculate_reciprocal_rank(retrieved_docs, relevant_docs),
            # Hit rate is 1 when at least one relevant document was retrieved.
            hit_rate=1.0 if true_positives > 0 else 0.0,
            query_id=query_id
        )

    def evaluate_batch(
        self,
        queries: List[str],
        retrieved_results: List[List[str]],
        ground_truth: List[List[str]]
    ) -> Dict[str, float]:
        """Evaluate several queries and return the mean of each metric.

        Args:
            queries: One entry per query (used only for the length check).
            retrieved_results: Ranked document ids per query.
            ground_truth: Relevant document ids per query.

        Returns:
            Mapping of metric name to its mean over all queries; empty when
            there are no queries.

        Raises:
            ValueError: If the three lists do not all have the same length.
        """
        # A chained `a != b != c` only compares neighbours and misses cases
        # such as lengths (2, 2, 1); require all three to be equal instead.
        if not (len(queries) == len(retrieved_results) == len(ground_truth)):
            raise ValueError("查询、检索结果和真实标签的数量不匹配")

        all_results = [
            self.evaluate_single_query(retrieved, relevant, query_id=str(i))
            for i, (retrieved, relevant) in enumerate(zip(retrieved_results, ground_truth))
        ]

        return self._calculate_average_metrics(all_results)

    def _calculate_ndcg(self, retrieved_docs: List[str], relevant_docs: List[str], k: int = 10) -> float:
        """Return NDCG@k with binary relevance (gain 1 per relevant document)."""
        relevant_set = set(relevant_docs)
        if not relevant_set:
            return 0.0

        # Each relevant id earns gain once; a repeated id is treated as
        # non-relevant so that DCG can never exceed the ideal DCG.
        uncredited = set(relevant_set)
        dcg = 0.0
        for rank, doc in enumerate(retrieved_docs[:k]):
            if doc in uncredited:
                uncredited.discard(doc)
                dcg += 1.0 / math.log2(rank + 2)  # rank + 2 because log2(1) == 0

        # Ideal DCG: all (distinct) relevant documents ranked first.
        ideal_hits = min(len(relevant_set), k)
        idcg = sum(1.0 / math.log2(rank + 2) for rank in range(ideal_hits))

        return dcg / idcg if idcg > 0 else 0.0

    def _calculate_average_precision(self, retrieved_docs: List[str], relevant_docs: List[str]) -> float:
        """Return the average precision of the ranking (0.0 without ground truth)."""
        relevant_set = set(relevant_docs)
        if not relevant_set:
            return 0.0

        credited = set()
        hits = 0
        precision_sum = 0.0

        for position, doc in enumerate(retrieved_docs, start=1):
            # Count each relevant id only the first time it appears.
            if doc in relevant_set and doc not in credited:
                credited.add(doc)
                hits += 1
                precision_sum += hits / position

        # Normalise by the number of distinct relevant documents.
        return precision_sum / len(relevant_set)

    def _calculate_reciprocal_rank(self, retrieved_docs: List[str], relevant_docs: List[str]) -> float:
        """Return 1 / rank of the first relevant document, or 0.0 if none is retrieved."""
        relevant_set = set(relevant_docs)

        for position, doc in enumerate(retrieved_docs, start=1):
            if doc in relevant_set:
                return 1.0 / position

        return 0.0

    def _calculate_average_metrics(self, results: List[EvaluationResult]) -> Dict[str, float]:
        """Average each metric over ``results``; returns ``{}`` for an empty list.

        Note that average precision is reported under the key ``'map'`` here,
        whereas a single result exposes it as ``map_score``.
        """
        if not results:
            return {}

        # (output key, attribute on EvaluationResult)
        columns = (
            ('precision', 'precision'),
            ('recall', 'recall'),
            ('f1_score', 'f1_score'),
            ('ndcg', 'ndcg'),
            ('map', 'map_score'),
            ('mrr', 'mrr'),
            ('hit_rate', 'hit_rate'),
        )

        # float(...) keeps NumPy scalar types out of the returned mapping.
        return {
            key: float(np.mean([getattr(r, attr) for r in results]))
            for key, attr in columns
        }

    def record_response_time(self, response_time: float) -> None:
        """Record the latency of one successful request."""
        self.response_times.append(response_time)
        self.total_requests += 1

    def record_error(self) -> None:
        """Record one failed request."""
        self.error_count += 1
        self.total_requests += 1

    def get_performance_stats(self, time_window: float = 60.0) -> PerformanceStats:
        """Summarise recorded latencies and errors.

        Args:
            time_window: Divisor used for throughput; throughput is the number
                of recorded response times divided by this value, so it is
                only meaningful if the recordings span that window.

        Returns:
            A ``PerformanceStats``. Latency figures and throughput are 0.0 when
            no response time has been recorded; the error rate is still
            reported in that case so an all-errors period is not hidden.
        """
        error_rate = self.error_count / self.total_requests if self.total_requests > 0 else 0.0

        if not self.response_times:
            return PerformanceStats(0.0, 0.0, 0.0, 0.0, error_rate)

        response_times = np.array(self.response_times)

        throughput = len(self.response_times) / time_window if time_window > 0 else 0.0

        return PerformanceStats(
            avg_response_time=float(np.mean(response_times)),
            p95_response_time=float(np.percentile(response_times, 95)),
            p99_response_time=float(np.percentile(response_times, 99)),
            throughput=throughput,
            error_rate=error_rate
        )

    def reset_stats(self) -> None:
        """Clear all recorded latencies and counters."""
        self.response_times = []
        self.error_count = 0
        self.total_requests = 0


class ABTestEvaluator:
    """Compare two retrieval systems on the same queries and ground truth."""

    def __init__(self):
        self.group_a_evaluator = RetrievalEvaluator()
        self.group_b_evaluator = RetrievalEvaluator()

    def compare_systems(
        self,
        queries: List[str],
        system_a_results: List[List[str]],
        system_b_results: List[List[str]],
        ground_truth: List[List[str]]
    ) -> Dict[str, Any]:
        """Evaluate both systems and summarise how B differs from A.

        Args:
            queries: One entry per query.
            system_a_results: Ranked document ids per query from system A.
            system_b_results: Ranked document ids per query from system B.
            ground_truth: Relevant document ids per query.

        Returns:
            A dict with the averaged metrics of each system, the relative
            change of B over A in percent (``improvements``), the outcome of
            the simplified significance test, and the ``winner`` label.
        """
        metrics_a = self.group_a_evaluator.evaluate_batch(queries, system_a_results, ground_truth)
        metrics_b = self.group_b_evaluator.evaluate_batch(queries, system_b_results, ground_truth)

        # Relative change of B versus A, in percent. A zero baseline has no
        # defined relative change and is reported as 0.0.
        improvements = {}
        for metric, baseline in metrics_a.items():
            if baseline > 0:
                improvements[metric] = (metrics_b[metric] - baseline) / baseline * 100
            else:
                improvements[metric] = 0.0

        significance_tests = self._simple_significance_test(
            system_a_results, system_b_results, ground_truth
        )

        return {
            'system_a_metrics': metrics_a,
            'system_b_metrics': metrics_b,
            'improvements': improvements,
            'significance_tests': significance_tests,
            'winner': self._determine_winner(metrics_a, metrics_b)
        }

    def _simple_significance_test(
        self,
        results_a: List[List[str]],
        results_b: List[List[str]],
        ground_truth: List[List[str]]
    ) -> Dict[str, Any]:
        """Run a rough paired test on per-query NDCG.

        This is a heuristic (|mean difference| > 2 standard errors), not a
        proper statistical test; use ``scipy.stats`` for real analyses.

        Returns:
            ``ndcg_significant`` (bool), ``mean_difference`` (float, B minus A)
            and a nominal ``confidence`` label (0.95 or 0.8). With fewer than
            two query pairs the spread cannot be estimated, so the result is
            reported as not significant.

        Raises:
            ValueError: If the three lists do not have the same length, since
                the comparison is paired per query.
        """
        if not (len(results_a) == len(results_b) == len(ground_truth)):
            raise ValueError("两个系统的检索结果和真实标签的数量不匹配")

        ndcg_a = [
            self.group_a_evaluator._calculate_ndcg(ranking, relevant)
            for ranking, relevant in zip(results_a, ground_truth)
        ]
        ndcg_b = [
            self.group_b_evaluator._calculate_ndcg(ranking, relevant)
            for ranking, relevant in zip(results_b, ground_truth)
        ]

        sample_size = len(ndcg_a)
        if sample_size == 0:
            # Nothing to compare; avoid NaN results from empty means.
            return {
                'ndcg_significant': False,
                'mean_difference': 0.0,
                'confidence': 0.8
            }

        mean_diff = float(np.mean(ndcg_b) - np.mean(ndcg_a))

        if sample_size < 2:
            is_significant = False
        else:
            # Sample standard deviation (ddof=1) of the paired differences
            # gives the standard error of their mean.
            paired_diffs = np.array(ndcg_b) - np.array(ndcg_a)
            std_error = float(np.std(paired_diffs, ddof=1)) / math.sqrt(sample_size)
            is_significant = bool(abs(mean_diff) > 2 * std_error)

        return {
            'ndcg_significant': is_significant,
            'mean_difference': mean_diff,
            'confidence': 0.95 if is_significant else 0.8
        }

    def _determine_winner(
        self,
        metrics_a: Dict[str, float],
        metrics_b: Dict[str, float]
    ) -> str:
        """Pick the system that leads on more of the key metrics.

        Only metrics present in both mappings are compared. Returns
        ``"System A"``, ``"System B"`` or ``"Tie"``.
        """
        key_metrics = ['ndcg', 'map', 'mrr', 'f1_score']

        a_wins = 0
        b_wins = 0

        for metric in key_metrics:
            if metric not in metrics_a or metric not in metrics_b:
                continue
            if metrics_b[metric] > metrics_a[metric]:
                b_wins += 1
            elif metrics_a[metric] > metrics_b[metric]:
                a_wins += 1

        if b_wins > a_wins:
            return "System B"
        if a_wins > b_wins:
            return "System A"
        return "Tie"


class QualityMonitor:
    """Track per-query quality and latency and raise threshold-based alerts."""

    def __init__(self, alert_thresholds: Optional[Dict[str, float]] = None):
        """Create a monitor.

        Args:
            alert_thresholds: Optional overrides for the default thresholds
                (keys ``precision``, ``recall``, ``ndcg``, ``response_time``,
                ``error_rate``). Keys that are not supplied keep their default
                value, so a partial mapping is safe to pass.
        """
        self.thresholds = self._resolve_thresholds(alert_thresholds)
        self.alerts = []
        self.evaluator = RetrievalEvaluator()

    @staticmethod
    def _resolve_thresholds(alert_thresholds: Optional[Dict[str, float]]) -> Dict[str, float]:
        """Return the default thresholds overlaid with any caller overrides."""
        thresholds = {
            'precision': 0.8,
            'recall': 0.7,
            'ndcg': 0.75,
            'response_time': 0.5,
            'error_rate': 0.05
        }
        if alert_thresholds:
            thresholds.update(alert_thresholds)
        return thresholds

    def monitor_query(
        self,
        query: str,
        retrieved_docs: List[str],
        relevant_docs: List[str],
        response_time: float
    ) -> Dict[str, Any]:
        """Record one query, evaluate it and report any alerts it triggers.

        Args:
            query: The query text; also used as the result's ``query_id``.
            retrieved_docs: Ranked document ids returned for the query.
            relevant_docs: Ground-truth relevant document ids.
            response_time: Observed latency, compared against the
                ``response_time`` threshold.

        Returns:
            A dict with the query, its metrics, the response time and the
            list of alert messages raised for it.
        """
        self.evaluator.record_response_time(response_time)

        result = self.evaluator.evaluate_single_query(retrieved_docs, relevant_docs, query)

        alerts = self._check_alerts(result, response_time)

        return {
            'query': query,
            'metrics': result.to_dict(),
            'response_time': response_time,
            'alerts': alerts
        }

    def _check_alerts(self, result: EvaluationResult, response_time: float) -> List[str]:
        """Return alert messages for breached thresholds and log them on ``self.alerts``."""
        alerts = []

        if result.precision < self.thresholds['precision']:
            alerts.append(f"精确率过低: {result.precision:.3f} < {self.thresholds['precision']}")

        if result.recall < self.thresholds['recall']:
            alerts.append(f"召回率过低: {result.recall:.3f} < {self.thresholds['recall']}")

        if result.ndcg < self.thresholds['ndcg']:
            alerts.append(f"NDCG过低: {result.ndcg:.3f} < {self.thresholds['ndcg']}")

        if response_time > self.thresholds['response_time']:
            alerts.append(f"响应时间过长: {response_time:.3f}s > {self.thresholds['response_time']}s")

        # Keep a timestamped history so health checks can look at recent alerts.
        for alert in alerts:
            self.alerts.append({
                'timestamp': time.time(),
                'message': alert,
                'severity': 'warning'
            })

        return alerts

    def _recent_alerts(self, window_seconds: float = 3600) -> List[Dict[str, Any]]:
        """Return the logged alerts raised within the last ``window_seconds``."""
        now = time.time()
        return [alert for alert in self.alerts if now - alert['timestamp'] < window_seconds]

    def get_health_status(self) -> Dict[str, Any]:
        """Summarise overall health from performance stats and recent alerts.

        Returns:
            A dict with ``health_score``, the ``performance`` statistics, the
            number of alerts in the last hour (``recent_alerts``), that count
            per minute (``alert_rate``) and a ``status`` label: ``healthy``
            above 0.8, ``warning`` above 0.6, otherwise ``critical``.
        """
        perf_stats = self.evaluator.get_performance_stats()
        recent_alert_count = len(self._recent_alerts())
        health_score = self._calculate_health_score(perf_stats)

        if health_score > 0.8:
            status = 'healthy'
        elif health_score > 0.6:
            status = 'warning'
        else:
            status = 'critical'

        return {
            'health_score': health_score,
            'performance': perf_stats.to_dict(),
            'recent_alerts': recent_alert_count,
            'alert_rate': recent_alert_count / 60,  # alerts per minute over the last hour
            'status': status
        }

    def _calculate_health_score(self, perf_stats: PerformanceStats) -> float:
        """Start from 1.0 and subtract a penalty for each breached condition."""
        score = 1.0

        # Slow average latency.
        if perf_stats.avg_response_time > self.thresholds['response_time']:
            score -= 0.2

        # Too many failed requests.
        if perf_stats.error_rate > self.thresholds['error_rate']:
            score -= 0.3

        # More than 10 alerts within the last hour.
        if len(self._recent_alerts()) > 10:
            score -= 0.2

        return max(0.0, score)


# Usage example
if __name__ == "__main__":

    # Sample data: three queries with rankings from two systems.
    queries = ["RAG技术", "检索算法", "重排序方法"]

    # Rankings returned by system A
    system_a_results = [
        ["doc1", "doc2", "doc3"],
        ["doc4", "doc5", "doc6"],
        ["doc7", "doc8", "doc9"],
    ]

    # Rankings returned by system B
    system_b_results = [
        ["doc1", "doc3", "doc2"],
        ["doc4", "doc6", "doc5"],
        ["doc8", "doc7", "doc9"],
    ]

    # Ground-truth relevant documents per query
    ground_truth = [
        ["doc1", "doc2"],
        ["doc4", "doc5"],
        ["doc7", "doc8"],
    ]

    # Evaluate a single query
    evaluator = RetrievalEvaluator()
    result = evaluator.evaluate_single_query(
        ["doc1", "doc2", "doc3"], ["doc1", "doc2"], "query1"
    )
    print("单次评估结果:")
    for metric, value in result.to_dict().items():
        print(f"  {metric}: {value:.3f}")

    # Evaluate a batch of queries
    print("\n批量评估结果:")
    batch_metrics = evaluator.evaluate_batch(queries, system_a_results, ground_truth)
    for metric, value in batch_metrics.items():
        print(f"  {metric}: {value:.3f}")

    # A/B comparison of the two systems
    print("\nA/B测试结果:")
    ab_evaluator = ABTestEvaluator()
    comparison = ab_evaluator.compare_systems(
        queries, system_a_results, system_b_results, ground_truth
    )

    print(f"获胜者: {comparison['winner']}")
    print("改进情况:")
    for metric, improvement in comparison['improvements'].items():
        print(f"  {metric}: {improvement:+.1f}%")

    # Quality monitoring with a slightly slower response for each query
    print("\n质量监控:")
    monitor = QualityMonitor()

    for i, (query, retrieved, relevant) in enumerate(
        zip(queries, system_a_results, ground_truth)
    ):
        monitor_result = monitor.monitor_query(
            query,
            retrieved,
            relevant,
            response_time=0.2 + i * 0.1,
        )

        triggered = monitor_result['alerts']
        if not triggered:
            continue

        print(f"查询 '{query}' 触发告警:")
        for alert in triggered:
            print(f"  - {alert}")

    # Overall health summary
    health = monitor.get_health_status()
    print(f"\n系统健康状态: {health['status']}")
    print(f"健康分数: {health['health_score']:.2f}")
    print(f"最近告警数: {health['recent_alerts']}")

单次评估结果:
  precision: 0.667
  recall: 1.000
  f1_score: 0.800
  ndcg: 1.000
  map_score: 1.000
  mrr: 1.000
  hit_rate: 1.000

批量评估结果:
  precision: 0.667
  recall: 1.000
  f1_score: 0.800
  ndcg: 1.000
  map: 1.000
  mrr: 1.000
  hit_rate: 1.000

A/B测试结果:
获胜者: System A
改进情况:
  precision: +0.0%
  recall: +0.0%
  f1_score: +0.0%
  ndcg: -5.4%
  map: -11.1%
  mrr: +0.0%
  hit_rate: +0.0%

质量监控:
查询 'RAG技术' 触发告警:
  - 精确率过低: 0.667 < 0.8
查询 '检索算法' 触发告警:
  - 精确率过低: 0.667 < 0.8
查询 '重排序方法' 触发告警:
  - 精确率过低: 0.667 < 0.8

系统健康状态: healthy
健康分数: 1.00
最近告警数: 3
