In [1]:
import gc
import os
import pickle
import jsonlines
import torch
from tqdm import tqdm
from collections import defaultdict
import argparse
from core.models.entailment import EntailmentDeberta
from core.data.data_utils import load_ds_from_json
from rank_eval import eval_beir_rank_result

def load_pickle_file(file_path):
    """Read ``file_path`` in binary mode and return the object unpickled from it.

    Warning: unpickling can execute arbitrary code, so only call this on files
    produced by a trusted source (e.g. by ``save_pickle_file`` in this project).
    """
    with open(file_path, 'rb') as handle:
        return pickle.load(handle)

def save_pickle_file(file_path, data):
    """Pickle ``data`` into ``file_path``, truncating any existing file there.

    Returns None; the only effect is the file write.
    """
    with open(file_path, 'wb') as out_file:
        pickle.dump(data, out_file)

In [2]:
# Evaluate, for every BEIR dataset, both the reranked run (output/rerank/...) and the
# first-stage ranking it was built from, collecting the metric dicts in `all_scores`.
BEIR_DATASET_NAMES = ["trec-covid", "climate-fever", "dbpedia-entity", "fever", "fiqa", "hotpotqa", "msmarco",  "nfcorpus", "nq", "scidocs", "scifact"]
SIZE_NAME = "toy"
K_VALUES = [1, 3, 5, 10]

# Data locations. Override through environment variables rather than editing
# machine-specific absolute paths; defaults preserve the original locations.
BEIR_ROOT = os.environ.get("BEIR_ROOT", "/home/song/dataset/beir")
FIRST_RANK_ROOT = os.environ.get("FIRST_RANK_ROOT", "/home/song/dataset/first/beir_rank")
RERANK_OUTPUT_DIR = "output/rerank"

# dataset name -> {"entropy": <rerank-run metrics>, "rank": <first-stage-run metrics>}
all_scores = defaultdict(dict)

for dataset_name in tqdm(BEIR_DATASET_NAMES):
    dataset_path = f"{BEIR_ROOT}/{dataset_name}"

    print(f"> {dataset_name} rerank:")
    rerank_result_path = f"{RERANK_OUTPUT_DIR}/{dataset_name}/rerank-{SIZE_NAME}.tsv"
    print(f"rerank_result_path: {rerank_result_path}")
    rerank_scores = eval_beir_rank_result(rerank_result_path, dataset_path, dataset_name, k_values=K_VALUES)
    all_scores[dataset_name]["entropy"] = rerank_scores

    print(f">> {dataset_name} rank:")
    rank_result_path = f"{FIRST_RANK_ROOT}/{dataset_name}/rank.tsv"
    print(f"rank_result_path: {rank_result_path}")
    rank_scores = eval_beir_rank_result(rank_result_path, dataset_path, dataset_name, k_values=K_VALUES)
    # Store the first-stage metrics here (previously rerank_scores was stored by mistake,
    # making the "rank" entry a duplicate of "entropy").
    all_scores[dataset_name]["rank"] = rank_scores
print("ALL DONE!")

  0%|          | 0/11 [00:00<?, ?it/s]

> trec-covid rerank:
rerank_result_path: output/rerank/trec-covid/rerank-toy.tsv


  0%|          | 0/171332 [00:00<?, ?it/s]

Retriever evaluation
map:
{'MAP@1': 0.00157, 'MAP@3': 0.00371, 'MAP@5': 0.00574, 'MAP@10': 0.01046}
precision:
{'P@1': 0.7, 'P@3': 0.56667, 'P@5': 0.6, 'P@10': 0.61}
ndcg:
{'NDCG@1': 0.55, 'NDCG@3': 0.5, 'NDCG@5': 0.51163, 'NDCG@10': 0.51762}
mrr:
{'MRR@1': 0.14, 'MRR@3': 0.15667, 'MRR@5': 0.15667, 'MRR@10': 0.16}
recall_cap:
{'R_cap@1': 0.14, 'R_cap@3': 0.12, 'R_cap@5': 0.12, 'R_cap@10': 0.122}
>> trec-covid rank:


  0%|          | 0/171332 [00:00<?, ?it/s]

  9%|▉         | 1/11 [00:02<00:23,  2.33s/it]

Retriever evaluation
map:
{'MAP@1': 0.00175, 'MAP@3': 0.0051, 'MAP@5': 0.00807, 'MAP@10': 0.01358}
precision:
{'P@1': 0.72, 'P@3': 0.72, 'P@5': 0.696, 'P@10': 0.624}
ndcg:
{'NDCG@1': 0.65, 'NDCG@3': 0.67061, 'NDCG@5': 0.64957, 'NDCG@10': 0.59638}
mrr:
{'MRR@1': 0.72, 'MRR@3': 0.78667, 'MRR@5': 0.79567, 'MRR@10': 0.80471}
recall_cap:
{'R_cap@1': 0.72, 'R_cap@3': 0.72, 'R_cap@5': 0.696, 'R_cap@10': 0.624}
> climate-fever rerank:
rerank_result_path: output/rerank/climate-fever/rerank-toy.tsv


  0%|          | 0/5416593 [00:00<?, ?it/s]

Retriever evaluation
map:
{'MAP@1': 0.08333, 'MAP@3': 0.08333, 'MAP@5': 0.1025, 'MAP@10': 0.14845}
precision:
{'P@1': 0.2, 'P@3': 0.06667, 'P@5': 0.1, 'P@10': 0.13}
ndcg:
{'NDCG@1': 0.2, 'NDCG@3': 0.10824, 'NDCG@5': 0.16002, 'NDCG@10': 0.26615}
mrr:
{'MRR@1': 0.00065, 'MRR@3': 0.00098, 'MRR@5': 0.00143, 'MRR@10': 0.00174}
recall:
{'Recall@1': 0.08333, 'Recall@3': 0.08333, 'Recall@5': 0.16667, 'Recall@10': 0.395}
>> climate-fever rank:


  0%|          | 0/5416593 [00:00<?, ?it/s]

Retriever evaluation
map:
{'MAP@1': 0.09848, 'MAP@3': 0.13828, 'MAP@5': 0.1515, 'MAP@10': 0.16494}
precision:
{'P@1': 0.21629, 'P@3': 0.13833, 'P@5': 0.10788, 'P@10': 0.07472}
ndcg:
{'NDCG@1': 0.21629, 'NDCG@3': 0.18945, 'NDCG@5': 0.20541, 'NDCG@10': 0.23712}
mrr:
{'MRR@1': 0.21629, 'MRR@3': 0.28339, 'MRR@5': 0.30137, 'MRR@10': 0.31756}
recall:
{'Recall@1': 0.09848, 'Recall@3': 0.17719, 'Recall@5': 0.22004, 'Recall@10': 0.29169}


 18%|█▊        | 2/11 [00:54<04:43, 31.45s/it]

> dbpedia-entity rerank:
rerank_result_path: output/rerank/dbpedia-entity/rerank-toy.tsv


  0%|          | 0/4635922 [00:00<?, ?it/s]

Retriever evaluation
map:
{'MAP@1': 0.01843, 'MAP@3': 0.0501, 'MAP@5': 0.06144, 'MAP@10': 0.08719}
precision:
{'P@1': 0.8, 'P@3': 0.7, 'P@5': 0.6, 'P@10': 0.51}
ndcg:
{'NDCG@1': 0.7, 'NDCG@3': 0.59491, 'NDCG@5': 0.55606, 'NDCG@10': 0.49936}
mrr:
{'MRR@1': 0.02, 'MRR@3': 0.02125, 'MRR@5': 0.02187, 'MRR@10': 0.02187}
recall:
{'Recall@1': 0.01843, 'Recall@3': 0.05201, 'Recall@5': 0.06886, 'Recall@10': 0.10333}
>> dbpedia-entity rank:


  0%|          | 0/4635922 [00:00<?, ?it/s]

 18%|█▊        | 2/11 [01:16<05:44, 38.32s/it]


KeyboardInterrupt: 

In [None]:
# Display the per-dataset metric dicts collected by the evaluation loop above
# (may be partial if that loop was interrupted).
all_scores