In [1]:
import gc
import os
import pickle
import jsonlines
import torch
from tqdm import tqdm
from collections import defaultdict
import argparse
from core.models.entailment import EntailmentDeberta
from core.data.data_utils import load_ds_from_json
from rank_eval import eval_beir_rerank_result

def load_pickle_file(file_path):
    """Read a pickle file and return the object stored in it.

    Args:
        file_path: path of the pickle file to read.

    Returns:
        The unpickled object.

    Note: unpickling can execute arbitrary code, so only load files you produced yourself.
    """
    with open(file_path, mode='rb') as handle:
        return pickle.load(handle)

def save_pickle_file(file_path, data):
    """Pickle ``data`` into ``file_path``, replacing the file if it already exists.

    Args:
        file_path: destination path; its parent directory must already exist.
        data: any picklable object.
    """
    with open(file_path, mode='wb') as handle:
        pickle.dump(data, handle)



def run_eval(dataset_names, size_name="small", beir_root=None, k_values=(1, 3, 5, 10)):
    """Evaluate first-stage rank results against entropy reranking for each dataset.

    For every dataset, ``eval_beir_rerank_result`` is called with the rank result file
    ``dataset/rank/<name>/<name>-rank10-<size_name>.tsv``, the entropy rerank file
    ``output/rerank/<name>/entropy-<size_name>.tsv`` and the dataset directory
    ``<beir_root>/<name>``. The collected scores are pickled to
    ``output/rerank/entropy_scores_<size_name>.pkl`` and returned.

    Args:
        dataset_names: dataset names to evaluate.
        size_name: size tag used in the input/output file names (e.g. "small", "toy").
        beir_root: directory containing one sub-directory per dataset. Defaults to the
            ``BEIR_ROOT`` environment variable, or the original hard-coded location.
        k_values: cut-offs forwarded to ``eval_beir_rerank_result``.

    Returns:
        dict mapping dataset name -> the result of ``eval_beir_rerank_result``. Downstream
        code indexes it as ``[method][metric_family][metric_label]`` (TODO confirm).
        A dataset whose evaluation raised is absent from the dict.
    """
    if beir_root is None:
        beir_root = os.environ.get('BEIR_ROOT', '/home/song/dataset/beir')

    all_scores = {}
    for dataset_name in tqdm(dataset_names):
        dataset_path = os.path.join(beir_root, dataset_name)
        rank_result_path = f'dataset/rank/{dataset_name}/{dataset_name}-rank10-{size_name}.tsv'
        entropy_result_path = f'output/rerank/{dataset_name}/entropy-{size_name}.tsv'
        try:
            all_scores[dataset_name] = eval_beir_rerank_result(
                rank_result_path, entropy_result_path, dataset_path, dataset_name,
                k_values=list(k_values),
            )
        except Exception as e:
            # Best effort: keep evaluating the other datasets, but report which one failed
            # so a missing entry in the summary can be traced.
            print(f"Error evaluating {dataset_name}: {e}")

    # Cache the scores so the averaging cells can be re-run without re-evaluating.
    output_path = f"output/rerank/entropy_scores_{size_name}.pkl"
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    save_pickle_file(output_path, all_scores)
    return all_scores
# all_scores = load_pickle_file('output/rerank/entropy_scores_small.pkl')

def calc_avg_score(all_scores, dataset_names, methods, all_metrics, missing_value=float("nan")):
    """Average every (metric, method) score over the given datasets.

    Each score is read as ``all_scores[dataset_name][method][metric_family][metric_label]``.
    The scores are packed into an array indexed ``[metric][method][dataset]`` and then
    averaged over the dataset axis.

    Args:
        all_scores: nested score dict (e.g. the result of ``run_eval``).
        dataset_names: datasets to average over.
        methods: method names; they become the DataFrame columns.
        all_metrics: ``(metric_family, metric_label)`` pairs; the labels become the row index.
        missing_value: value used for any score that cannot be looked up.
            - The default, NaN, excludes missing entries from the average, so a failed
              dataset does not drag a method's mean toward zero.
            - ``0.0`` reproduces the previous behaviour of counting missing entries as zero.

    Returns:
        pandas.DataFrame of shape ``(len(all_metrics), len(methods))``. With the default
        ``missing_value``, an entry is NaN when no dataset provided that score.
    """
    import warnings

    import numpy as np
    import pandas as pd

    score_array = np.full(
        (len(all_metrics), len(methods), len(dataset_names)), missing_value, dtype=float
    )
    for i, dataset_name in enumerate(dataset_names):
        for j, method in enumerate(methods):
            for k, (metric_family, metric_label) in enumerate(all_metrics):
                try:
                    score_array[k, j, i] = all_scores[dataset_name][method][metric_family][metric_label]
                except (KeyError, IndexError, TypeError, ValueError):
                    # Dataset/method/metric absent or unusable (e.g. evaluation failed for
                    # that dataset): keep missing_value rather than aborting the table.
                    pass

    with warnings.catch_warnings():
        # nanmean warns "Mean of empty slice" for all-missing rows; NaN is the intended result.
        warnings.simplefilter("ignore", category=RuntimeWarning)
        averages = np.nanmean(score_array, axis=-1)

    return pd.DataFrame(averages, index=[label for _, label in all_metrics], columns=methods)


# BEIR datasets to evaluate.
dataset_names = [
    "trec-covid", "climate-fever", "dbpedia-entity", "fever",
    "hotpotqa", "nfcorpus", "nq", "scidocs",
]
methods = ["rank", "entropy", "rerank"]

# (metric family key, display label) pairs, each expanded to the @1, @10, @3, @5 cut-offs
# in that order: MAP, MRR, NDCG, P, Recall, R_cap -> 24 entries.
all_metrics = [
    (family, f"{label}@{k}")
    for family, label in [
        ('map', 'MAP'), ('mrr', 'MRR'), ('ndcg', 'NDCG'),
        ('precision', 'P'), ('recall', 'Recall'), ('recall_cap', 'R_cap'),
    ]
    for k in (1, 10, 3, 5)
]

all_scores = run_eval(dataset_names)
df = calc_avg_score(all_scores, dataset_names, methods, all_metrics)
# Show only the metrics at cut-off @5.
df.loc[df.index.str.endswith('@5')]

  0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/171332 [00:00<?, ?it/s]

 12%|█▎        | 1/8 [00:01<00:08,  1.24s/it]

Success count: 500, success1 count: 14, fail count: 0


  0%|          | 0/5416593 [00:00<?, ?it/s]

 25%|██▌       | 2/8 [00:25<01:30, 15.06s/it]

Success count: 40, success1 count: 2, fail count: 460


  0%|          | 0/4635922 [00:00<?, ?it/s]

 38%|███▊      | 3/8 [00:45<01:24, 16.91s/it]

Success count: 10, success1 count: 0, fail count: 490


  0%|          | 0/5416568 [00:00<?, ?it/s]

 50%|█████     | 4/8 [01:08<01:18, 19.56s/it]

Success count: 0, success1 count: 0, fail count: 500


  0%|          | 0/5233329 [00:00<?, ?it/s]

 62%|██████▎   | 5/8 [01:29<01:00, 20.08s/it]

Success count: 0, success1 count: 0, fail count: 500


  0%|          | 0/3633 [00:00<?, ?it/s]

 75%|███████▌  | 6/8 [01:29<00:26, 13.30s/it]

Success count: 10, success1 count: 0, fail count: 490


  0%|          | 0/2681468 [00:00<?, ?it/s]

 88%|████████▊ | 7/8 [01:40<00:12, 12.48s/it]

Success count: 10, success1 count: 0, fail count: 490


  0%|          | 0/25657 [00:00<?, ?it/s]

100%|██████████| 8/8 [01:41<00:00, 12.70s/it]

Success count: 0, success1 count: 0, fail count: 500
(24, 3, 8)





Unnamed: 0,rank,entropy,rerank
MAP@5,0.283144,0.065489,0.283144
MRR@5,0.128754,0.109799,0.128754
NDCG@5,0.459321,0.217816,0.459321
P@5,0.2785,0.200887,0.2785
Recall@5,0.346438,0.146546,0.346438
R_cap@5,0.087,0.077,0.087


In [2]:
# Show only the metrics at cut-off @10.
df.loc[df.index.str.endswith('@10')]

Unnamed: 0,rank,entropy,rerank
MAP@10,0.295324,0.105758,0.295324
MRR@10,0.130358,0.112075,0.130358
NDCG@10,0.46088,0.306845,0.46088
P@10,0.20525,0.217044,0.20525
Recall@10,0.391511,0.399086,0.391511
R_cap@10,0.078,0.078,0.078


In [3]:
# Show only the metrics at cut-off @1 ('@10' labels do not match this suffix).
df.loc[df.index.str.endswith('@1')]

Unnamed: 0,rank,entropy,rerank
MAP@1,0.208441,0.014234,0.208441
MRR@1,0.115706,0.091201,0.115706
NDCG@1,0.48625,0.143776,0.48625
P@1,0.5125,0.159415,0.5125
Recall@1,0.208217,0.014059,0.208217
R_cap@1,0.0925,0.0825,0.0925


In [4]:
# Show only the metrics at cut-off @3.
df.loc[df.index.str.endswith('@3')]

Unnamed: 0,rank,entropy,rerank
MAP@3,0.267944,0.051774,0.267944
MRR@3,0.12701,0.106375,0.12701
NDCG@3,0.46195,0.199419,0.46195
P@3,0.345834,0.205226,0.345834
Recall@3,0.306573,0.095734,0.306573
R_cap@3,0.088334,0.078334,0.088334
