In [5]:
import numpy as np
import Levenshtein  # 用于计算编辑距离
import pandas as pd
import re
from typing import List, Dict, Tuple

class ELEAIEvaluator:
    """Score a takeaway voice assistant on three sub-tasks and combine them.

    Sub-scores:

    * ASR -- ``1 - mean(CER)`` over (reference, hypothesis) transcript pairs;
    * DOM -- accuracy of the binary takeaway-domain classifier (labels 0/1);
    * QUE -- quality of rewritten queries, averaged over samples the classifier
      predicted as takeaway (label ``1``); a query only earns credit when it
      names a dish present in the loaded dish database.

    The total score is the weighted sum of the three sub-scores.
    """

    # Text following "点" or "要" is treated as a candidate dish name. Capture
    # stops at common full-width (Chinese) and ASCII clause punctuation, so
    # e.g. "帮我点沙拉！" yields "沙拉" rather than "沙拉！".
    _DISH_PATTERNS = (
        re.compile(r'点([^，。！？、；,.!?;]+)'),
        re.compile(r'要([^，。！？、；,.!?;]+)'),
    )

    def __init__(self, weights=(0.5, 0.3, 0.2)):
        """Create an evaluator.

        Args:
            weights: Weights for the (ASR, DOM, QUE) scores used by
                calculate_total_score. Defaults to ``(0.5, 0.3, 0.2)``.
        """
        self.weights = weights
        # Known dish names; stays empty until load_dish_database() succeeds.
        self.dish_database = set()

    def load_dish_database(self, dish_file_path):
        """Load dish names from the ``item`` column of an Excel file.

        Empty cells are skipped. Every value is converted to text and
        whitespace-stripped, and blank names are dropped. Loading is
        best-effort: on any failure a message is printed and the current
        database is left unchanged.

        Args:
            dish_file_path: Path to an Excel file with an ``item`` column.
        """
        try:
            df = pd.read_excel(dish_file_path)
            # dropna + astype(str): NaN cells would otherwise land in the set as
            # float NaN, and non-text cells (e.g. numeric codes) would break the
            # ``.str`` accessor for the whole column.
            items = df['item'].dropna().astype(str).str.strip()
            self.dish_database = set(items[items != ''])
            print(f"成功加载 {len(self.dish_database)} 个菜品")
        except Exception as e:
            # Deliberately best-effort (notebook usage): report and continue.
            print(f"加载菜品数据库失败: {e}")

    def preprocess_text(self, text):
        """Remove punctuation from ``text``.

        Args:
            text: Input text. Non-string values (e.g. None/NaN) are treated as
                empty.

        Returns:
            ``text`` with every character that is neither a word character nor
            whitespace removed (``\\w`` is Unicode-aware, so CJK characters are
            kept), or ``""`` for non-string input.
        """
        if not isinstance(text, str):
            return ""
        return re.sub(r'[^\w\s]', '', text)

    def calculate_cer(self, reference, hypothesis):
        """Compute the character error rate (CER) of one transcript pair.

        Both texts are stripped of punctuation first.

        Args:
            reference: Ground-truth transcript.
            hypothesis: Model transcript.

        Returns:
            Levenshtein edit distance divided by the reference length. The value
            can exceed 1.0 when the hypothesis is much longer than the
            reference. If the preprocessed reference is empty, returns 1.0 when
            the hypothesis is non-empty and 0.0 otherwise.
        """
        reference = self.preprocess_text(reference)
        hypothesis = self.preprocess_text(hypothesis)

        if len(reference) == 0:
            return 1.0 if len(hypothesis) > 0 else 0.0

        edit_distance = Levenshtein.distance(reference, hypothesis)
        return edit_distance / len(reference)

    def evaluate_asr(self, references, hypotheses):
        """Score ASR quality as ``1 - mean(CER)``.

        Args:
            references: Ground-truth transcripts.
            hypotheses: Model transcripts, aligned with ``references``.

        Returns:
            The ASR score. Because CER is unbounded above, the score can be
            negative for very poor transcripts.

        Raises:
            ValueError: If the lists differ in length or are empty.
        """
        if len(references) != len(hypotheses):
            raise ValueError("参考文本和假设文本数量不一致")
        # Without this check np.mean([]) returns NaN (with a RuntimeWarning),
        # which silently poisons the total score.
        if len(references) == 0:
            raise ValueError("参考文本和假设文本不能为空")

        cers = [self.calculate_cer(ref, hyp)
                for ref, hyp in zip(references, hypotheses)]
        return 1 - np.mean(cers)

    def evaluate_dom(self, true_labels, predicted_labels):
        """Score the domain classifier (DOM) by accuracy.

        Args:
            true_labels: Ground-truth labels (0 or 1).
            predicted_labels: Predicted labels (0 or 1), aligned with
                ``true_labels``.

        Returns:
            Fraction of positions where the prediction equals the true label.

        Raises:
            ValueError: If the lists differ in length or are empty.
        """
        if len(true_labels) != len(predicted_labels):
            raise ValueError("真实标签和预测标签数量不一致")
        # Guard against ZeroDivisionError on empty input.
        if len(true_labels) == 0:
            raise ValueError("真实标签和预测标签不能为空")

        correct = sum(1 for t, p in zip(true_labels, predicted_labels) if t == p)
        return correct / len(true_labels)

    def evaluate_que(self, queries, domain_predictions, model_scores):
        """Score query rewriting (QUE).

        Only samples whose domain prediction is 1 (takeaway) are scored. Each
        of them contributes ``q_i * score / 5``, where ``q_i`` is 1 if any
        dish extracted from the query is in the dish database and 0 otherwise.

        Args:
            queries: Rewritten queries.
            domain_predictions: Domain predictions (0 or 1), aligned with
                ``queries``.
            model_scores: Grader scores on a 0-5 scale, aligned with
                ``queries``.

        Returns:
            Mean per-query score over takeaway-predicted samples, or 0.0 when
            no sample is predicted as takeaway.

        Raises:
            ValueError: If the three lists differ in length.
        """
        if not (len(queries) == len(domain_predictions) == len(model_scores)):
            raise ValueError("查询、领域预测和模型评分数量不一致")

        valid_indices = [i for i, pred in enumerate(domain_predictions) if pred == 1]
        if not valid_indices:
            return 0.0

        total_score = 0
        for idx in valid_indices:
            # q_i gates the grader score: credit only when the query names at
            # least one dish known to the database.
            dishes_in_query = self.extract_dishes(queries[idx])
            q_i = 1 if any(dish in self.dish_database for dish in dishes_in_query) else 0
            # Normalize the 0-5 grader score to 0-1.
            total_score += (q_i * model_scores[idx]) / 5

        return total_score / len(valid_indices)

    def extract_dishes(self, query):
        """Extract candidate dish names from a query (heuristic).

        Text following "点" or "要" up to the next clause punctuation mark is
        taken as a candidate. This is a simple heuristic; a real system would
        need proper NLP.

        Args:
            query: Query text. Non-string values yield no candidates.

        Returns:
            Whitespace-stripped, non-empty candidates. Matches for "点" come
            first, then matches for "要".
        """
        if not isinstance(query, str):
            return []

        dishes = []
        for pattern in self._DISH_PATTERNS:
            for match in pattern.findall(query):
                # Strip so candidates are comparable with the stripped database
                # entries; drop whitespace-only captures.
                dish = match.strip()
                if dish:
                    dishes.append(dish)
        return dishes

    def calculate_total_score(self, asr_score, dom_score, que_score):
        """Combine the sub-scores with ``self.weights``.

        Args:
            asr_score: ASR score.
            dom_score: DOM score.
            que_score: QUE score.

        Returns:
            ``w0 * asr + w1 * dom + w2 * que``.
        """
        return (
            self.weights[0] * asr_score +
            self.weights[1] * dom_score +
            self.weights[2] * que_score
        )

    def evaluate(self, test_data):
        """Run all three evaluations and the weighted total.

        Args:
            test_data: Mapping with keys ``asr_references``, ``asr_hypotheses``,
                ``dom_true_labels``, ``dom_predicted_labels``, ``que_queries``
                and ``que_model_scores``.

        Returns:
            Dict with ``asr_score``, ``dom_score``, ``que_score`` and
            ``total_score``.
        """
        asr_score = self.evaluate_asr(
            test_data['asr_references'],
            test_data['asr_hypotheses']
        )

        dom_score = self.evaluate_dom(
            test_data['dom_true_labels'],
            test_data['dom_predicted_labels']
        )

        # QUE is gated on the classifier's *predicted* labels, not the truth.
        que_score = self.evaluate_que(
            test_data['que_queries'],
            test_data['dom_predicted_labels'],
            test_data['que_model_scores']
        )

        total_score = self.calculate_total_score(asr_score, dom_score, que_score)

        return {
            'asr_score': asr_score,
            'dom_score': dom_score,
            'que_score': que_score,
            'total_score': total_score
        }


# 使用示例
def main():
    """Demonstrate the evaluator end to end on a tiny hand-written sample.

    Loads the dish database from a path relative to the working directory,
    evaluates three ASR/DOM/QUE samples and prints every score to four
    decimal places.
    """
    evaluator = ELEAIEvaluator()
    evaluator.load_dish_database('../data/商品.xlsx')

    # ASR: hypotheses differ from the references only by punctuation.
    asr_refs = [
        "我在减肥，帮我点个不长胖的外卖",
        "给我来个酸辣土豆丝，不要辣的",
        "我想吃麻辣烫，要特别辣的那种",
    ]
    asr_hyps = [
        "我在减肥帮我点个不长胖的外卖",
        "给我来个酸辣土豆丝不要辣的",
        "我想吃麻辣烫要特别辣的那种",
    ]

    # DOM: all three samples are takeaway; the classifier misses the last one.
    dom_truth = [1, 1, 1]
    dom_pred = [1, 1, 0]

    # QUE: rewritten queries and the grader's 0-5 scores for each.
    que_queries = ["帮我点沙拉", "帮我点酸辣土豆丝", "帮我点麻辣烫"]
    que_scores = [4, 5, 3]

    results = evaluator.evaluate({
        'asr_references': asr_refs,
        'asr_hypotheses': asr_hyps,
        'dom_true_labels': dom_truth,
        'dom_predicted_labels': dom_pred,
        'que_queries': que_queries,
        'que_model_scores': que_scores,
    })

    print("\n===== 评估结果 =====")
    print(f"ASR得分: {results['asr_score']:.4f}")
    print(f"DOM得分: {results['dom_score']:.4f}")
    print(f"QUE得分: {results['que_score']:.4f}")
    print(f"总分: {results['total_score']:.4f}")


if __name__ == "__main__":
    main()


成功加载 3427 个菜品

===== 评估结果 =====
ASR得分: 1.0000
DOM得分: 0.6667
QUE得分: 0.5000
总分: 0.8000
