In [1]:
import numpy as np 
import pandas as pd
from collections import Counter
from pandasql import sqldf
import re
import os 
os.environ['CUDA_VISIBLE_DEVICES']='1'
import joblib
import json 
import jieba 
import copy

from openai import AzureOpenAI
import matplotlib.pyplot as plt 
from sentence_transformers import SentenceTransformer
from FlagEmbedding import FlagReranker
from transformers import AutoTokenizer
from sklearn.metrics import average_precision_score

In [2]:
# pandasql查询函数需要的环境
def pysqldf(q):
    """Run the pandasql query ``q`` against this module's global namespace (where the notebook's DataFrames live)."""
    return sqldf(q, globals())

In [3]:
# 原始数据处理
def format_model(x):
    """Split a comma-separated model string into normalized model names.

    Every token is stripped and lower-cased. Tokens are then walked left to right:

    * a token equal to its predecessor (and not the last token) is glued to the
      following token, and both are consumed (``"a,a,pro"`` -> ``["a", "apro"]``);
    * a trailing ``"上下水"`` or ``"air"`` token is glued onto the previous token,
      replacing the bare previous name (``"x,air"`` -> ``["xair"]``);
    * every other token is kept as-is.

    Args:
        x: comma-separated model string.

    Returns:
        list[str]: the normalized model names.
    """
    tokens = [tok.strip().lower() for tok in x.split(',')]
    last = len(tokens) - 1
    new_list = [tokens[0]]
    i = 1
    while i < len(tokens):
        if i != last and tokens[i - 1] == tokens[i]:
            # Repeated token: treat it as a prefix of the next token and consume both.
            new_list.append(tokens[i] + tokens[i + 1])
            i += 2
        elif i != last:
            new_list.append(tokens[i])
            i += 1
        elif tokens[i] in ("上下水", "air"):
            # Trailing suffix: replace the bare preceding model with the suffixed name.
            prev = tokens[i - 1]
            if prev in new_list:
                new_list.remove(prev)  # removes the first occurrence only
            new_list.append(prev + tokens[i])
            i += 1
        else:
            new_list.append(tokens[i])
            break
    return new_list

def format_all_models(x, dim_df):
    """Expand "<category>全型号" ("all models of <category>") tokens into concrete model names.

    Args:
        x: list of model tokens.
        dim_df: DataFrame with ``cat_name`` and ``model`` columns.

    Returns:
        list: tokens without the marker are kept as-is. A marker token is replaced
        by every model of that category that is not already present in ``x``.
    """
    expanded = []
    for token in x:
        marker_pos = token.find("全型号")
        if marker_pos < 0:
            expanded.append(token)
            continue
        category = token[:marker_pos]
        category_models = dim_df.loc[dim_df['cat_name'] == category, 'model'].tolist()
        expanded.extend(m for m in category_models if m not in x)
    return expanded

def format_series(x, dim_df):
    """Expand "<name>系列" ("<name> series") tokens with the matching models.

    For each series token, every model in ``dim_df.model`` that contains ``<name>``
    as a substring and has no CJK character (U+4E00..U+9FFF) is added, unless it is
    already in ``x``. The series token itself is kept after those models. Other
    tokens pass through unchanged.

    Args:
        x: list of model tokens.
        dim_df: DataFrame with a ``model`` column.

    Returns:
        list: the expanded token list.
    """
    def contains_chinese(s):
        return re.search('[\u4e00-\u9fff]', s) is not None

    expanded = []
    for token in x:
        pos = token.find("系列")
        if pos < 0:
            expanded.append(token)
            continue
        series_name = token[:pos]
        models = dim_df.model
        hit = (models.str.find(series_name) >= 0) & models.apply(lambda m: not contains_chinese(m))
        expanded.extend(m for m in models[hit].tolist() if m not in x)
        expanded.append(token)
    return expanded

In [4]:
# 拼接openai embedding
def generate_embeddings(text, model="text-embedding-ada-002", api_client=None):  # model = "deployment_name"
    """Return the embedding vector of ``text`` from an OpenAI-style embeddings endpoint.

    Args:
        text: the string to embed (sent as a single-element batch).
        model: model / Azure deployment name passed to ``embeddings.create``.
        api_client: object exposing ``embeddings.create(input=..., model=...)``
            (e.g. an ``AzureOpenAI`` instance). When omitted, the module-level
            ``client`` is used, as before.

    Returns:
        The ``embedding`` of the first (only) item in the response.
    """
    if api_client is None:
        # NOTE(review): the global `client` is not created in any visible cell; pass
        # api_client explicitly or create `client` before calling.
        api_client = client
    response = api_client.embeddings.create(input=[text], model=model)
    return response.data[0].embedding

In [5]:
# 测试集处理及计算与正确qa的相似度
def cosine_similarity(a, b):
    """Cosine similarity of two vectors (lists or 1-D arrays).

    There is no zero-norm guard: a zero vector gives nan (numpy emits a RuntimeWarning).
    """
    denominator = np.linalg.norm(a) * np.linalg.norm(b)
    return np.dot(a, b) / denominator

def search_docs(df, user_query, top_n=4, to_print=True):
    """Rank ``df`` rows by cosine similarity of their ``ada_002`` embedding to the query.

    Args:
        df: candidate rows with ``ada_002``, ``qa_id``, ``question`` and ``answer`` columns.
        user_query: text embedded with ``generate_embeddings``.
        top_n: number of records to return.
        to_print: accepted but unused.

    Returns:
        list[dict]: the ``top_n`` best rows (qa_id, question, answer, similarities).

    Side effect: writes/overwrites a ``similarities`` column on ``df`` in place.
    """
    query_vec = generate_embeddings(user_query)
    df["similarities"] = df.ada_002.apply(lambda row_vec: cosine_similarity(row_vec, query_vec))
    best = df.sort_values("similarities", ascending=False).head(top_n)
    return best[["qa_id", "question", "answer", "similarities"]].to_dict(orient='records')

def concat(x):
    """Join the values of a pandas Series into a single comma-separated string.

    Each value is cast with ``astype(str)`` first.
    """
    as_text = x.astype(str).tolist()
    return ",".join(as_text)

def format_gt(x):
    """Turn a newline-separated ground-truth id string into a comma-separated one.

    Values whose ``str()`` is ``"nan"`` (e.g. a float NaN from ``read_csv``) are returned unchanged.
    """
    if str(x) != "nan":
        return x.replace("\n", ",")
    return x

def count_gt(x):
    """Number of comma-separated ground-truth ids in ``x``; 0 when ``str(x) == "nan"``."""
    return 0 if str(x) == "nan" else len(x.split(","))

In [6]:
# 向量召回
# def search_docs_bge(df, user_query, top_n=4, to_print=True):
#     embedding = model.encode(user_query, normalize_embeddings=True).tolist()
#     df["similarities"] = df.bge_large.apply(lambda x: cosine_similarity(x, embedding))
#     output_columns = ["qa_id", "question", "answer", "similarities"]
#     if "hit_reason" in df.columns:
#         output_columns.append("hit_reason")
#     res = (
#         df.sort_values("similarities", ascending=False)
#         .head(top_n)
#     )[output_columns]
#     return res.to_dict(orient='records')

def search_docs_bge(df, user_query, top_n=4, to_print=True, add_instruction=False):
    """Rank rows of ``df`` by cosine similarity of their ``bge_large`` vector to the query.

    The query is encoded with the module-level ``model.encode_queries(..., max_length=96)``.
    ``model`` is the FlagModel loaded elsewhere in the notebook.

    Args:
        df: candidate rows with ``bge_large``, ``qa_id``, ``question``, ``answer`` and
            optionally ``hit_reason`` columns. The caller's frame is not modified.
        user_query: query text.
        top_n: number of records to return.
        to_print: accepted but unused.
        add_instruction: accepted but ignored.
            NOTE(review): whether a retrieval instruction is prepended depends on how
            the global ``model`` was configured, not on this flag.

    Returns:
        list[dict]: the ``top_n`` best rows with qa_id, question, answer, similarities
        (and hit_reason when present).
    """
    candidates = df.copy()
    query_vec = model.encode_queries(user_query, max_length=96).tolist()
    candidates["similarities"] = candidates.bge_large.apply(lambda doc_vec: cosine_similarity(doc_vec, query_vec))
    columns = ["qa_id", "question", "answer", "similarities"]
    if "hit_reason" in candidates.columns:
        columns.append("hit_reason")
    ranked = candidates.sort_values("similarities", ascending=False).head(top_n)
    return ranked[columns].to_dict(orient='records')

def find_non_chinese_substrings(s):
    """Return the maximal runs of non-Chinese word characters in ``s``.

    ASCII spaces are allowed inside a run. Runs made only of whitespace are dropped.
    """
    # [^\u4e00-\u9fff\W] is a "word" character outside the CJK range (letters, digits, _).
    # Alternating it with a literal space inside a non-capturing group lets a run span
    # spaces (e.g. " AB 12 "), while punctuation and Chinese characters end the run.
    pattern = r'(?:[^\u4e00-\u9fff\W]| )+'
    return [run for run in re.findall(pattern, s) if not run.isspace()]

def clean_string(s):
    """Remove every ASCII space from ``s`` and lower-case the result."""
    no_spaces = "".join(s.split(" "))
    return no_spaces.lower()

def find_model(x, all_model_list):
    """Return the models from ``all_model_list`` mentioned in text ``x``, in list order.

    A model counts as mentioned when it exactly equals a non-Chinese run of ``x``
    after spaces are removed and the run is lower-cased.
    """
    flattened = x.replace("\n", "")
    mentioned = [clean_string(chunk) for chunk in find_non_chinese_substrings(flattened)]
    return [name for name in all_model_list if name in mentioned]

def find_cat(x, all_cat_list):
    """Return the categories from ``all_cat_list`` that occur as substrings of ``x``, in list order."""
    matched = []
    for category in all_cat_list:
        if category in x:
            matched.append(category)
    return matched

def filter_model(x, model_list):
    """True if any entry of ``model_list`` equals one of the comma-separated tokens of ``x`` (no stripping)."""
    tokens = x.split(",")
    return any(candidate in tokens for candidate in model_list)

def find_error_with_reason(a):
    """Extract the error codes mentioned in ``a`` as normalized "错误<digits>" tokens.

    Both the "错误 <n>" and "错误原因 <n>" forms are found. All matches of the first
    form come first, then all matches of the second. ASCII spaces and "原因" are
    then removed from each match.
    """
    found = []
    for pattern in (r"错误\s*\d+", r"错误原因\s*\d+"):
        found.extend(re.findall(pattern, a))
    return [hit.replace(" ", "").replace("原因", "") for hit in found]

def filter_reason(x, query_reason_list):
    """True if any error code in ``query_reason_list`` is also mentioned in text ``x``."""
    codes_in_text = find_error_with_reason(x)
    return any(code in codes_in_text for code in query_reason_list)

def transform_model_name(x, all_model_list):
    """Rewrite model mentions in ``x`` into their canonical form from ``all_model_list``.

    Newlines are removed first. Each non-Chinese run (spaces included) whose cleaned
    form is a known model is replaced, at every occurrence, by that cleaned form.
    Unknown runs are left alone.
    """
    text = x.replace("\n", "")
    for raw in find_non_chinese_substrings(text):
        canonical = clean_string(raw)
        if canonical in all_model_list:
            text = text.replace(raw, canonical)
    return text

def remove_model_name(x, all_model_list):
    """Delete from ``x`` every non-Chinese run whose cleaned form is a known model (newlines removed first)."""
    text = x.replace("\n", "")
    removable = [raw for raw in find_non_chinese_substrings(text) if clean_string(raw) in all_model_list]
    for raw in removable:
        text = text.replace(raw, "")
    return text

class BM25_Model(object):
    """Okapi-BM25 scorer over a fixed corpus of pre-tokenized documents.

    Args:
        documents_list: list of documents, each a list of tokens.
        k1: saturation constant for document term frequency.
        k2: saturation constant for query term frequency.
        b: strength of document-length normalization (0 = none, 1 = full).

    Note: IDF is log((N - n + 0.5) / (n + 0.5)). It is negative for terms that
    occur in more than half of the documents, which then lower the score.
    """

    def __init__(self, documents_list, k1=2, k2=1, b=0.5):
        self.documents_list = documents_list
        self.documents_number = len(documents_list)
        # Average document length in tokens; an empty corpus gets 0.0 instead of
        # raising ZeroDivisionError.
        self.avg_documents_len = (
            sum(len(document) for document in documents_list) / self.documents_number
            if self.documents_number else 0.0
        )
        self.f = []    # per-document {term: count}
        self.idf = {}  # {term: idf}
        self.k1 = k1
        self.k2 = k2
        self.b = b
        self.init()

    def init(self):
        """Build per-document term counts (``self.f``) and corpus IDF (``self.idf``)."""
        doc_freq = Counter()
        for document in self.documents_list:
            term_counts = dict(Counter(document))
            self.f.append(term_counts)
            doc_freq.update(term_counts.keys())
        for term, n in doc_freq.items():
            self.idf[term] = np.log((self.documents_number - n + 0.5) / (n + 0.5))

    def get_score(self, index, query):
        """BM25 score of document ``index`` for the tokenized ``query``.

        Each distinct query term contributes once. Its repetition count in the query
        enters through the k2 saturation factor; it is not summed once per occurrence.
        """
        term_counts = self.f[index]
        if not term_counts:
            return 0.0
        # |D| is the document length in tokens, the same unit as avg_documents_len
        # (previously the number of distinct terms was used by mistake).
        document_len = sum(term_counts.values())
        length_norm = 1 - self.b + self.b * document_len / self.avg_documents_len
        score = 0.0
        for term, q_count in Counter(query).items():
            tf = term_counts.get(term)
            if tf is None:
                continue
            doc_part = tf * (self.k1 + 1) / (tf + self.k1 * length_norm)
            query_part = q_count * (self.k2 + 1) / (q_count + self.k2)
            score += self.idf[term] * doc_part * query_part
        return score

    def get_documents_score(self, query, indices):
        """Scores of the documents at ``indices`` for ``query``, in the same order."""
        return [self.get_score(i, query) for i in indices]


class WordCut:
    """Tokenizer for BM25.

    ``cut`` strips punctuation, canonicalizes model names, segments the text with
    jieba, then drops stop words and single-character tokens.

    Args:
        all_model_list: canonical model names passed to ``transform_model_name``.
            ``None`` (the default) is treated as an empty list.
        stopwords_path: UTF-8 stop-word file, one word per line. Defaults to the
            path that was previously hard-coded.
    """

    # Chinese / full-width punctuation (CJK symbols, full-width forms, curly quotes,
    # ellipsis, bullets, dashes, vertical forms) plus common ASCII punctuation.
    PUNCTUATION_PATTERN = r'[\u3000-\u303f\uff01-\uff0f\uff1a-\uff20\uff3b-\uff40\uff5b-\uff65\u2018\u2019\u201c\u201d\u2026\u00a0\u2022\u2013\u2014\u2010\u2027\uFE10-\uFE1F\u3001-\u301E]|[\.,!¡?¿\-—_(){}[\]\'\";:/]'

    def __init__(self, all_model_list=None, stopwords_path='/data/dataset/kefu/hit_stopwords.txt'):
        with open(stopwords_path, encoding='utf-8') as f:
            # Drop the line terminator of every line read.
            self.stop_words = {line.replace("\n", "") for line in f}
        self.all_model_list = all_model_list

    def cut(self, mytext):
        """Tokenize ``mytext`` into a list of multi-character, non-stop-word tokens."""
        # jieba.load_userdict(...) can be called beforehand to keep domain terms that
        # jieba would otherwise split.
        new_data = re.sub(self.PUNCTUATION_PATTERN, '', mytext)
        # transform_model_name iterates the model list, so None must become [].
        model_names = self.all_model_list if self.all_model_list is not None else []
        new_data = transform_model_name(new_data, model_names)
        seg_list_exact = jieba.lcut(new_data)
        return [word for word in seg_list_exact if word not in self.stop_words and len(word) > 1]

def search_docs_bm25(df, indices, user_query, top_n=4, word_cutter=None, bm25=None):
    """Rank rows of ``df`` by BM25 score against ``user_query`` and return the top ``top_n``.

    Args:
        df: candidate rows with ``qa_id``, ``question``, ``answer`` (and optionally
            ``hit_reason``) columns.
        indices: corpus positions of ``df``'s rows in the BM25 model, one per row, in row order.
        user_query: raw query text, tokenized with ``word_cutter.cut``.
        top_n: number of records to return.
        word_cutter: object with ``cut(text) -> list[str]``. Defaults to the
            module-level ``wc``.
        bm25: object with ``get_documents_score(tokens, indices)``, e.g.
            ``BM25_Model([word_cutter.cut(doc) for doc in df.question])``. Defaults
            to the module-level ``bm25_model``.

    Returns:
        list[dict]: the ``top_n`` best rows.

    Side effect: writes a ``similarities`` column onto ``df`` in place (kept for
    backward compatibility).
    """
    # NOTE(review): the globals `wc` / `bm25_model` are not created in any visible cell.
    if word_cutter is None:
        word_cutter = wc
    if bm25 is None:
        bm25 = bm25_model
    query_tokens = word_cutter.cut(user_query)
    df["similarities"] = bm25.get_documents_score(query_tokens, indices)
    output_columns = ["qa_id", "question", "answer", "similarities"]
    if "hit_reason" in df.columns:
        output_columns.append("hit_reason")
    res = df.sort_values("similarities", ascending=False).head(top_n)[output_columns]
    return res.to_dict(orient='records')

In [7]:
# 综合分析
def ranking_metric(x):
    """Map a hit-reason string to a priority bucket (lower = stronger evidence).

    1: error + model, 2: error + cat, 3: error, 4: model, 5: cat, 6: anything else.
    Matching is by substring, so e.g. "errorcode" counts as "error".
    """
    has_error = x.find("error") >= 0
    has_model = x.find("model") >= 0
    has_cat = x.find("cat") >= 0
    if has_error:
        if has_model:
            return 1
        return 2 if has_cat else 3
    if has_model:
        return 4
    if has_cat:
        return 5
    return 6

def mine_hard_negative(x, similarities, reason, result, positive):
    """Pick up to 10 hard negatives from one row that holds parallel list columns.

    Candidates whose ``result`` is one of the comma-separated ``positive`` ids are
    removed. The rest are ordered by reason priority (``ranking_metric``), then by
    descending similarity, and de-duplicated on ``result``.

    Returns:
        pd.Series with keys ``<similarities>_hard``, ``<reason>_hard`` and ``<result>_hard``.
    """
    positives = x[positive].split(",")
    candidates = pd.DataFrame(x[[similarities, reason, result]].to_dict())
    negatives = candidates[~candidates[result].isin(positives)]
    ordered = (
        negatives.assign(ranking=negatives[reason].apply(ranking_metric))
        .sort_values(["ranking", similarities], ascending=[True, False])
        .drop_duplicates(result)
        .iloc[:10]
    )
    return pd.Series({col + "_hard": ordered[col].values.tolist() for col in [similarities, reason, result]})

def format_result(x, similarities, reason, result):
    """Collapse one row's parallel recall lists into unique results, best evidence first.

    For repeated results, the stored entry is replaced only when the new hit has a
    reason priority at least as good AND a strictly higher similarity.

    Returns:
        list[dict]: ``{"result", "reason", "similarities"}`` records sorted by reason
        priority, then by descending similarity.
    """
    names = x[result]
    sims = x[similarities]
    codes = x[reason]
    best = dict()
    for j in range(len(names)):
        name, sim, code = names[j], sims[j], codes[j]
        current = best.get(name)
        if current is None:
            best[name] = {"similarities": sim, "reason": code}
            continue
        not_worse = ranking_metric(code) <= ranking_metric(current["reason"])
        if not_worse & (sim > current["similarities"]):
            best[name] = {"similarities": sim, "reason": code}
    table = pd.DataFrame(best).T.reset_index().rename(columns={"index": "result"})
    table["ranking"] = table["reason"].apply(ranking_metric)
    table = table.sort_values(["ranking", "similarities"], ascending=[True, False])
    return table[['result', 'reason', 'similarities']].to_dict(orient='records')

def merge_recall(x, recall_list, weights):
    """Fuse several recall lists of one row into a single weighted ranking.

    Args:
        x: row holding, under each name in ``recall_list``, a list of
            ``{"result", "reason", "similarities"}`` dicts.
        recall_list: column names of the recall lists to merge.
        weights: one weight per recall list.

    Returns:
        list[dict]: unique results with the best (lowest-priority) reason, the
        weight-normalized sum of similarities, and ``full_reason`` (comma-joined
        ``<reason>_<recall name>`` tags), sorted by priority then similarity.
    """
    total_weight = sum(weights)
    pool = {}
    for i, weight in enumerate(weights):
        recall_name = recall_list[i]
        items = x[recall_name]
        for j in range(len(items)):
            item = items[j]
            key = item["result"]
            if key in pool:
                entry = pool[key]
                if ranking_metric(item['reason']) < ranking_metric(entry['reason']):
                    entry['reason'] = item['reason']
                entry['similarities'] += weight * item['similarities'] / total_weight
                entry['full_reason'] = entry['full_reason'] + "," + item['reason'] + "_" + recall_name
            else:
                pool[key] = {
                    "reason": item['reason'],
                    "similarities": weight * item['similarities'] / total_weight,
                    "full_reason": item['reason'] + "_" + recall_name,
                }
    merged = pd.DataFrame(pool).T.reset_index().rename(columns={"index": "result"})
    merged["ranking"] = merged["reason"].apply(ranking_metric)
    merged = merged.sort_values(["ranking", "similarities"], ascending=[True, False])
    return merged[['result', 'reason', 'full_reason', 'similarities']].to_dict(orient='records')

def mine_hard_negative2(x, recall, top_n, positive, output_cols):
    """Return the ``top_n`` best non-positive entries of a merged recall list as records.

    Entries whose ``result`` is one of the comma-separated ``positive`` ids are
    removed. The rest are ordered by reason priority, then by descending similarity.
    """
    pool = pd.DataFrame(x[recall])
    positives = x[positive].split(",")
    negatives = pool[~pool["result"].isin(positives)]
    ordered = (
        negatives.assign(ranking=negatives["reason"].apply(ranking_metric))
        .sort_values(["ranking", "similarities"], ascending=[True, False])
        .iloc[:top_n]
    )
    return ordered[output_cols].to_dict(orient='records')

def split_recall(x, output_cols, new_cols):
    """Turn a list of record dicts into a Series of per-field lists, renaming ``output_cols`` to ``new_cols``."""
    frame = pd.DataFrame(x)
    columns = {}
    for src, dst in zip(output_cols, new_cols):
        columns[dst] = frame[src].values.tolist()
    return pd.Series(columns)

def find_score_limit(x):
    """Global ``(min, max)`` over a collection of score lists.

    Args:
        x: iterable of non-empty score sequences (a list, or a pandas Series of lists
            with any index).

    Returns:
        tuple: ``(min_all, max_all)``; ``(inf, -inf)`` when ``x`` is empty.
    """
    min_all = float("inf")
    max_all = float("-inf")
    # Iterate values rather than x[i]: on a pandas Series with a non-default index,
    # x[i] is a label lookup and raises KeyError.
    for scores in x:
        min_all = min(min_all, min(scores))
        max_all = max(max_all, max(scores))
    return min_all, max_all

def convert_limit(x, min_all, max_all):
    """Min-max scale the values of ``x`` using the bounds ``min_all`` / ``max_all``.

    When ``max_all == min_all`` the span is treated as 1 (as scikit-learn's
    MinMaxScaler does for constant features). Values then map to ``x - min_all``,
    i.e. all 0.0 for constant input, instead of raising ZeroDivisionError.

    Returns:
        list: the scaled values, in order.
    """
    span = max_all - min_all
    if span == 0:
        span = 1
    return [(i - min_all) / span for i in x]

def convert_df_to_jsonl(df, filename, query="question_cleaned", pos_col="question_positive", neg_col="question_bge_hard"):
    """Write one ``{"query", "pos", "neg"}`` JSON object per row of ``df`` to ``filename`` (JSON Lines).

    ``json.dumps`` is used with its defaults, so non-ASCII text is written as ASCII
    escape sequences. The file is overwritten.
    """
    records = (
        {"query": row[query], "pos": row[pos_col], "neg": row[neg_col]}
        for _, row in df.iterrows()
    )
    with open(filename, 'w') as file:
        for record in records:
            file.write(json.dumps(record) + '\n')

In [8]:
# 排序
def mrr_at_k_score(is_relevant, pred_ranking, k):
    """Reciprocal rank of the first relevant document within the top ``k`` (0 if there is none).

    Args:
        is_relevant (`List[bool]`): relevance flag per document.
        pred_ranking (`List[int]`): document indices sorted by decreasing predicted score.
        k (`int`): cutoff.

    Returns:
        `float`: ``1 / rank`` (1-based) of the first relevant hit, or 0.
    """
    for position, doc_index in enumerate(pred_ranking[:k], start=1):
        if is_relevant[doc_index]:
            return 1 / position
    return 0

def recall_at_k_score(is_relevant, pred_ranking, k):
    """Hit rate @k: 1 if any relevant document is in the top ``k``, else 0.

    This is a binary "success@k" per query. It is not the fraction of relevant
    documents retrieved. (The previous docstring wrongly described MRR@k.)

    Args:
        is_relevant (`List[bool]`): relevance flag per document.
        pred_ranking (`List[int]`): document indices sorted by decreasing predicted score.
        k (`int`): cutoff.

    Returns:
        `int`: 1 or 0.
    """
    return int(any(is_relevant[idx] for idx in pred_ranking[:k]))

def ap_score(is_relevant, pred_scores):
    """Average precision of one query's ranking, computed by ``sklearn.metrics.average_precision_score``.

    Args:
        is_relevant (`List[bool]`): relevance flag per document.
        pred_scores (`List[float]`): predicted similarity score per document.

    Returns:
        `float`: the AP value.
    """
    return average_precision_score(is_relevant, pred_scores)

def compute_recall_score(df, model, query, recall):
    """Score every (query, recalled passage) pair of ``df`` with a cross-encoder in one batch.

    Args:
        df: one row per query.
        model: object with ``compute_score(pairs)``, e.g. a FlagEmbedding ``FlagReranker``.
        query: column holding the query text.
        recall: column holding the list of recalled passages for that query.

    Returns:
        list: one entry per row of ``df``, holding the slice of scores for that row's
        passages, in order.
    """
    queries = df[query].tolist()
    recalled = df[recall].tolist()
    pairs = [[q, passage] for q, passages in zip(queries, recalled) for passage in passages]
    if not pairs:
        # Nothing to score: do not call the model with an empty batch.
        all_scores = []
    else:
        all_scores = model.compute_score(pairs)
        # NOTE(review): some reranker versions may return a bare float for a single
        # pair; wrap it so the slicing below works.
        if np.ndim(all_scores) == 0:
            all_scores = [all_scores]
    result = []
    start = 0
    for passages in recalled:
        stop = start + len(passages)
        result.append(all_scores[start:stop])
        start = stop
    return result

def compute_metrics_batched_from_crossencoder(df, score, relevant,
                                              mrr_at_k=10, recall_at_list=(1, 2), metrics=("map", "mrr", "recall")):
    """Average ranking metrics over all queries in ``df``.

    Args:
        df: one row per query.
        score: column holding that query's per-item predicted scores (list-like).
        relevant: column holding the parallel per-item relevance flags.
        mrr_at_k: cutoff for MRR.
        recall_at_list: cutoffs for hit-rate ("recall") @k. A tuple default avoids a
            shared mutable default; lists are still accepted.
        metrics: any subset of ``"map"``, ``"mrr"``, ``"recall"``.

    Returns:
        dict: ``"map"``, ``"mrr@<k>"`` and ``"recall@<k>"`` entries for the requested
        metrics, each the mean over queries.
    """
    all_mrr_scores = []
    all_ap_scores = []
    all_recall_scores = [[] for _ in recall_at_list]

    for is_relevant, sample_scores in zip(df[relevant], df[score]):
        pred_scores = np.array(sample_scores)
        pred_scores_argsort = np.argsort(-pred_scores)  # item indices by decreasing score
        if "mrr" in metrics:
            all_mrr_scores.append(mrr_at_k_score(is_relevant, pred_scores_argsort, mrr_at_k))
        if "map" in metrics:
            all_ap_scores.append(ap_score(is_relevant, pred_scores))
        if "recall" in metrics:
            for recall_index, recall_at in enumerate(recall_at_list):
                all_recall_scores[recall_index].append(
                    recall_at_k_score(is_relevant, pred_scores_argsort, recall_at))

    result = {}
    if "map" in metrics:
        result["map"] = np.mean(all_ap_scores)
    if "mrr" in metrics:
        result[f"mrr@{mrr_at_k}"] = np.mean(all_mrr_scores)
    if "recall" in metrics:
        for recall_index, recall_at in enumerate(recall_at_list):
            result[f"recall@{recall_at}"] = np.mean(all_recall_scores[recall_index])
    return result

def find_T_loc(x, relevant, score):
    """0-based position of the first relevant item after sorting by descending score; nan if none is relevant."""
    is_relevant = x[relevant]
    order = np.argsort(-np.array(x[score]))
    return next((rank for rank, idx in enumerate(order) if is_relevant[idx]), np.nan)

def get_reranking(x, relevant, score, recall, reason, postranking=False):
    """Reorder one query's recalled items by cross-encoder score, optionally with rule-based post-ranking.

    Args:
        x: one row (dict-like) holding three parallel lists under the given names:
            relevance flags (``relevant``), cross-encoder scores (``score``) and the
            recalled item dicts (``recall``).
        relevant, score, recall: keys of ``x`` as described above.
        reason: key inside each recall dict holding the hit-reason string. Only used
            when ``postranking`` is True.
        postranking: if True, move up to two "special" items to the front (see below).
            This mode also reads the hard-coded keys ``similarities`` and ``result``
            from each recall dict.

    Returns:
        list[dict]: deep copies of the recall dicts in cross-encoder order, each
        annotated with ``relevant``, ``recall_order`` (original position),
        ``ranking_score`` and ``ranking_order``. Post-ranking mode also adds
        ``if_special``, ``if_top``, ``reranking_score`` and ``reranking_order``.
    """
    is_relevant = x[relevant]
    pred_scores = np.array(x[score])
    pred_scores_argsort = np.argsort(-pred_scores)
    # Deep copy so the annotations below do not mutate the caller's row.
    recall_list = copy.deepcopy(x[recall])
    for index, i in enumerate(recall_list):
        i.update({"relevant": is_relevant[index], "recall_order": index})
    reranking = []
    for index, i in enumerate(pred_scores_argsort):
        # `temp` is the copied recall dict itself, annotated in place.
        temp = recall_list[i]
        temp.update({"ranking_score": pred_scores[i], "ranking_order": index})
        reranking.append(temp)
    if postranking:
        reranking = pd.DataFrame(reranking)
        # An item is "special" if its reason mentions a model or an error, or holds
        # several comma-separated reasons...
        reranking["if_special"] = reranking[reason].apply(
            lambda x: (x.find("model") >= 0)|(x.find("error") >= 0)|(len(x.split(","))>1)).astype(int)
        # ...and only if its recall similarity exceeds the hard-coded 0.75 threshold.
        reranking.loc[(reranking["if_special"]==1)&(reranking["similarities"]<=0.75), "if_special"] = 0
        reranking["ranking"] = reranking[reason].apply(lambda x: ranking_metric(x))
        # The two best special items (by reason priority, then similarity) go first.
        top_list = reranking[reranking["if_special"]==1].sort_values(
            ["if_special", "ranking", "similarities"], ascending=[False, True, False])["result"].iloc[:2].tolist()
        reranking["if_top"] = reranking["result"].isin(top_list)
        # Promoted items first; everything else keeps its cross-encoder order.
        reranking = pd.concat([reranking[reranking["if_top"]==True], reranking[reranking["if_top"]==False]], axis=0).reset_index(drop=True)
        # Synthetic descending score so that score order matches the final position.
        reranking["reranking_score"] = list(range(reranking.shape[0]))[::-1]
        reranking["reranking_order"] = list(range(reranking.shape[0]))
        reranking = reranking.drop("ranking", axis=1)
        reranking = reranking.to_dict(orient="records")
    return reranking

# 向量召回

In [9]:
# Out-of-time evaluation set. The evaluation cell reads its `question`, `qa_id` and `gt_qa_id` columns.
OOT_PATH = "/data/dataset/kefu/oot20240315.csv"
oot = pd.read_csv(OOT_PATH)

In [10]:
# test = pd.read_csv("/data/dataset/kefu/test20240315.csv")
# df1 = pd.read_csv("/data/dataset/kefu/database_before_online_with_emb.csv")
# QA knowledge base. Downstream cells read qa_id, question, answer, model_list and cat_name.
QA_DB_PATH = "/data/dataset/kefu/database_with_emb20240315.csv"
df2 = pd.read_csv(QA_DB_PATH)

In [11]:
# model = SentenceTransformer('/data/dataset/huggingface/hub/bge-large-zh-v1.5')
# model = SentenceTransformer('/workspace/data/private/zhuxiaohai/models/bge_emb_finetune_from_db/checkpoint-400')
from FlagEmbedding import FlagModel

# Fine-tuned BGE retriever. The retrieval instruction is given to FlagModel, presumably
# so that `encode_queries` (used by search_docs_bge) prepends it to queries.
BGE_MODEL_DIR = '/workspace/data/private/zhuxiaohai/models/bge_emb_finetune_from_db'
model = FlagModel(
    BGE_MODEL_DIR,
    query_instruction_for_retrieval="为这个句子生成表示以用于检索相关文章：",
    use_fp16=False,
)
# The knowledge base is indexed by its *answers*; `q_embeddings` keeps its historical name.
q_embeddings = model.encode(df2.answer.tolist(), batch_size=32)
df2['bge_large'] = q_embeddings.tolist()

Inference Embeddings: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 95/95 [01:20<00:00,  1.18it/s]


In [12]:
# df2 = df2[df2.answer.apply(lambda x: True if x.lower().strip()[:10].find("http")<0 else False)]
# df2 = df2[df2.question.apply(lambda x: True if x.lower().strip().find("开箱图")<0 else False)]

In [13]:
DIM_PATH = "/data/dataset/kefu/dim_df20240315.csv"
dim_df = pd.read_csv(DIM_PATH)
# Vocabularies used to detect model and category mentions in questions.
all_model_list = dim_df["model"].tolist()
all_cat_list = dim_df["cat_name"].unique().tolist()

In [14]:
# Work on copies so the loaded frames are not modified by the evaluation below.
test, df1 = oot.copy(), df2.copy()

In [18]:
# test = test.sample(40)
# test["gt_qa_id"] = test['qa_id']

In [21]:
# 标签+向量3

In [15]:
# Hit-rate sweep for tag-filtered BGE retrieval on the OOT set.
# For each question, every knowledge-base row is tagged by whether it shares a product
# model, a product category or an error code with the question. BGE search then runs
# inside those tagged subsets, each pass returning int(top_n/2) rows. A question is a
# hit when any retrieved qa_id is in its comma-separated `gt_qa_id` list.
# Prints top_n and the hit rate; acc_list collects the hit rates.
# NOTE(review): the number of rows retrieved per question is not fixed at top_n (up to
# three passes of int(top_n/2) can run), so the hit rates are not at an exact budget.
top_list = range(2, 22, 2)
acc_list = []
for top_n in top_list:
    # Retrieve from the subset selected by `filter_mask`, split by error-code matches.
    # NOTE(review): reads `filter_mask`, `reason_indicator` and `question` as notebook
    # globals assigned in the per-question loop below, not as parameters.
    # `result`, `score` and `reason` are extended in place and also returned.
    def sub_worker(result, score, reason, top_n):
        # Case 1: some subset rows also match the question's error code. Search those
        # first, then the remaining subset rows.
        if (filter_mask & (reason_indicator.str.find("errorcode")>=0)).sum() > 0:
            aug_mask = filter_mask & (reason_indicator.str.find("errorcode")>=0)
            filtered_df = df1[aug_mask].copy()
            filtered_df["hit_reason"] = reason_indicator[aug_mask].copy()
            res = search_docs_bge(filtered_df, question, top_n=int(top_n/2), add_instruction=True)
            result += [j["qa_id"] for j in res]
            score += [round(j["similarities"], 2) for j in res]
            reason += [j["hit_reason"] for j in res]

            aug_mask = filter_mask & (~(reason_indicator.str.find("errorcode")>=0))
            filtered_df = df1[aug_mask].copy()
            filtered_df["hit_reason"] = reason_indicator[aug_mask].copy()
            res = search_docs_bge(filtered_df, question, top_n=int(top_n/2), add_instruction=True)
            result += [j["qa_id"] for j in res]
            score += [round(j["similarities"], 2) for j in res]
            reason += [j["hit_reason"] for j in res]
        else:
            # Case 2: no subset row matches the error code. Search the whole subset,
            # then the error-code rows (if any), which here all lie outside the subset.
            aug_mask = filter_mask
            filtered_df = df1[aug_mask].copy()
            filtered_df["hit_reason"] = reason_indicator[aug_mask].copy()
            res = search_docs_bge(filtered_df, question, top_n=int(top_n/2), add_instruction=True)
            result += [j["qa_id"] for j in res]
            score += [round(j["similarities"], 2) for j in res]
            reason += [j["hit_reason"] for j in res]

            aug_mask = (reason_indicator.str.find("errorcode")>=0)
            if aug_mask.sum() > 0:
                filtered_df = df1[aug_mask].copy()
                filtered_df["hit_reason"] = reason_indicator[aug_mask].copy()
                res = search_docs_bge(filtered_df, question, top_n=int(top_n/2), add_instruction=True)
                result += [j["qa_id"] for j in res]
                score += [round(j["similarities"], 2) for j in res]
                reason += [j["hit_reason"] for j in res]

        # Finally add category matches outside the subset. This is always empty when
        # the subset is the category mask itself.
        # NOTE(review): in case 2 this pass can return rows already returned by the
        # error-code pass (an error-code row that also matches a category), so
        # `result` may contain duplicate qa_ids.
        aug_mask = (~filter_mask) & (reason_indicator.str.find("cat")>=0)
        if aug_mask.sum() > 0:
            filtered_df = df1[aug_mask].copy()
            filtered_df["hit_reason"] = reason_indicator[aug_mask].copy()
            res = search_docs_bge(filtered_df, question, top_n=int(top_n/2), add_instruction=True)
            result += [j["qa_id"] for j in res]
            score += [round(j["similarities"], 2) for j in res]
            reason += [j["hit_reason"] for j in res]
        return result, score, reason
    label = []
    # NOTE(review): reset for every top_n, so only the last sweep value's details survive.
    result_list = []
    for i in range(test.shape[0]):
        gt = test['gt_qa_id'].iloc[i].split(",")
        question = test['question'].iloc[i]
        # Tag the question: models mentioned, categories mentioned (plus the categories
        # of the mentioned models), and error codes mentioned.
        model_list = find_model(question, all_model_list)
        cat_list = find_cat(question, all_cat_list)   
        cat_list += [cat for cat in dim_df.loc[dim_df.model.isin(model_list), 'cat_name'].tolist() if cat not in cat_list]
        reason_list = find_error_with_reason(question)
        model_mask = (df1.model_list.apply(lambda x: filter_model(x, model_list)))
        cat_mask = (df1.cat_name.apply(lambda x: filter_model(x, cat_list)))
        reason_mask = (df1.question.apply(lambda x: filter_reason(x, reason_list)))
        # Per-row "|"-joined tags ("model", "cat", "errorcode"), or "none".
        reason_indicator = pd.Series(["none"]*df1.shape[0], index=df1.index)
        reason_indicator[model_mask] = reason_indicator[model_mask].apply(lambda x: x + "|model" if x != "none" else "model")
        reason_indicator[cat_mask] = reason_indicator[cat_mask].apply(lambda x: x + "|cat" if x != "none" else "cat")
        reason_indicator[reason_mask] = reason_indicator[reason_mask].apply(lambda x: x + "|errorcode" if x != "none" else "errorcode")
        result = []
        score = []
        reason = []
        # question = remove_model_name(question, all_model_list)
        # Prefer the model-matched subset; otherwise fall back to the category-matched subset.
        filter_mask = (reason_indicator.str.find("model")>=0)
        if filter_mask.sum() > 0:
            result, score, reason = sub_worker(result, score, reason, top_n)
        else:
            filter_mask = (reason_indicator.str.find("cat")>=0)   
            if filter_mask.sum() > 0:
                result, score, reason = sub_worker(result, score, reason, top_n)
        # No model or category match: if the error code matches rows, split the budget
        # between those rows and all others; otherwise search the whole knowledge base.
        if len(result) == 0:
            filter_mask = (reason_indicator.str.find("errorcode")>=0)
            if filter_mask.sum() > 0:
                aug_mask = filter_mask
                filtered_df = df1[aug_mask].copy()
                filtered_df["hit_reason"] = reason_indicator[aug_mask].copy()
                res = search_docs_bge(filtered_df, question, top_n=int(top_n/2), add_instruction=True)
                result += [j["qa_id"] for j in res]
                score += [round(j["similarities"], 2) for j in res]
                reason += [j["hit_reason"] for j in res]  
                aug_mask = (~filter_mask)
                filtered_df = df1[aug_mask].copy()
                filtered_df["hit_reason"] = reason_indicator[aug_mask].copy()
                res = search_docs_bge(filtered_df, question, top_n=int(top_n/2), add_instruction=True)
                result += [j["qa_id"] for j in res]
                score += [round(j["similarities"], 2) for j in res]
                reason += [j["hit_reason"] for j in res]  
            else:
                filtered_df = df1.copy()
                filtered_df["hit_reason"] = reason_indicator.copy()
                res = search_docs_bge(filtered_df, question, top_n=top_n, add_instruction=True)
                result += [j["qa_id"] for j in res]
                score += [round(j["similarities"], 2) for j in res]
                reason += [j["hit_reason"] for j in res]

        # Hit if any retrieved qa_id is in the ground truth. The comparison is
        # string-to-value, so qa_id must be read as strings for matches to occur.
        found = False
        for j in result:
            if j in gt:
                found = True
                break 
        if found:
            label.append(1)
        else:
            label.append(0)
        result_list.append({"qa_id": test['qa_id'].iloc[i], 
                            "result": result, 
                            "similarities": score, 
                            "hit_reason": reason, 
                            "label": int(found)})
    acc_list.append(sum(label)/len(label))
    print(top_n, sum(label)/len(label))

2 0.3076923076923077
4 0.5897435897435898
6 0.6153846153846154
8 0.6923076923076923
10 0.7435897435897436
12 0.7948717948717948
14 0.8205128205128205
16 0.8461538461538461
18 0.8461538461538461
20 0.8461538461538461
