In [17]:
import gc
import os
import pickle
import jsonlines
import torch
from tqdm import tqdm
import pandas as pd
import csv
from collections import defaultdict
import argparse
from core.models.entailment import EntailmentDeberta
from rank_eval import load_data, load_rank_results

def merge_score(rank_score, entropy_score, *, threshold=0.01, bonus=1.0):
    """Combine a rank score with an optional entropy score.

    The rank score is boosted by ``bonus`` when an entropy score is
    available and is strictly below ``threshold``; otherwise the rank
    score is returned unchanged.

    Args:
        rank_score: Score produced by the ranker.
        entropy_score: Entropy score for the same (query, doc) pair, or
            ``None`` when no entropy score is available.
        threshold: Exclusive upper bound on ``entropy_score`` for the
            boost to apply. Defaults to ``0.01``.
        bonus: Amount added to ``rank_score`` when the boost applies.
            Defaults to ``1.0``.

    Returns:
        ``rank_score + bonus`` if ``entropy_score`` is not ``None`` and is
        below ``threshold``; ``rank_score`` otherwise.
    """
    # A missing entropy score never triggers the boost.
    if entropy_score is not None and entropy_score < threshold:
        return rank_score + bonus
    return rank_score

# For each dataset, join the first-stage rank results with the entropy
# rerank results and dump one TSV row per (query, ranked doc) pair.
dataset_names = ["trec-covid", "climate-fever", "dbpedia-entity", "fever", "hotpotqa", "nfcorpus", "nq", "scidocs"]

# Column order of every row written below.
merge_header = ['qid', 'query', 'docid', 'doc', 'gold_score', 'rank_index', 'rank_score', 'entropy_score', 'merge_score']

# Make sure the output directory exists so open() cannot fail on a fresh checkout.
merge_output_dir = 'output/tmp'
os.makedirs(merge_output_dir, exist_ok=True)

for dataset_name in tqdm(dataset_names, desc='dataset'):
    dataset_path = f'/home/song/dataset/beir/{dataset_name}'
    queries1, docs1, scores = load_data(dataset_path, dataset_name)
    # Normalise ids to strings; queries keep only their 'text' field,
    # docs are kept exactly as load_data returned them.
    queries = {str(qid): query['text'] for qid, query in queries1.items()}
    docs = {str(docid): doc for docid, doc in docs1.items()}

    rank_result_path = f'dataset/rank/{dataset_name}/{dataset_name}-rank10-small.tsv'
    rank_results = load_rank_results(rank_result_path)
    entropy_result_path = f'output/rerank/{dataset_name}/entropy-small.tsv'
    entropy_results = load_rank_results(entropy_result_path)
    print(f"dataset: {dataset_name}")

    merge_results = []
    for qid in rank_results:
        doc_ranks = rank_results[qid]
        # A query may have no entropy results at all; fall back to an empty mapping.
        doc_entropies = entropy_results.get(qid, {})
        qid_key = str(qid)
        query_text = queries.get(qid_key, '')
        gold_scores = scores.get(qid_key, {})
        # rank_index is the position of the doc in the iteration order of
        # the rank results for this query.
        for rank_index, docid in enumerate(doc_ranks):
            docid_key = str(docid)
            # Look each score up once and reuse it for the merged score.
            rank_score = doc_ranks.get(docid, 0.0)
            entropy_score = doc_entropies.get(docid, None)
            merge_results.append([
                qid_key,
                query_text,
                docid_key,
                # NOTE(review): doc is written as returned by load_data; if it
                # is not a plain string csv will stringify it - confirm.
                docs.get(docid_key, ''),
                gold_scores.get(docid_key, 0.0),
                rank_index,
                rank_score,
                entropy_score,
                merge_score(rank_score, entropy_score),
            ])

    output_path = f'{merge_output_dir}/merge-small-{dataset_name}.tsv'
    # Explicit UTF-8 so non-ASCII query/document text does not depend on the locale.
    with open(output_path, 'w', newline='', encoding='utf-8') as f:
        writer = csv.writer(f, delimiter='\t')
        writer.writerow(merge_header)
        writer.writerows(merge_results)
    print(f"output: {output_path}")

dataset:   0%|          | 0/8 [00:00<?, ?it/s]

  0%|          | 0/171332 [00:00<?, ?it/s]

dataset:  12%|█▎        | 1/8 [00:01<00:09,  1.33s/it]

dataset: trec-covid
output: output/tmp/merge-small-trec-covid.tsv


  0%|          | 0/5416593 [00:00<?, ?it/s]

dataset:  25%|██▌       | 2/8 [00:26<01:30, 15.12s/it]

dataset: climate-fever
output: output/tmp/merge-small-climate-fever.tsv


  0%|          | 0/4635922 [00:00<?, ?it/s]

dataset:  38%|███▊      | 3/8 [00:48<01:33, 18.61s/it]

dataset: dbpedia-entity
output: output/tmp/merge-small-dbpedia-entity.tsv


  0%|          | 0/5416568 [00:00<?, ?it/s]

dataset:  50%|█████     | 4/8 [01:14<01:26, 21.58s/it]

dataset: fever
output: output/tmp/merge-small-fever.tsv


  0%|          | 0/5233329 [00:00<?, ?it/s]

dataset:  62%|██████▎   | 5/8 [01:39<01:07, 22.64s/it]

dataset: hotpotqa
output: output/tmp/merge-small-hotpotqa.tsv


  0%|          | 0/3633 [00:00<?, ?it/s]

dataset:  75%|███████▌  | 6/8 [01:40<00:30, 15.44s/it]

dataset: nfcorpus
output: output/tmp/merge-small-nfcorpus.tsv


  0%|          | 0/2681468 [00:00<?, ?it/s]

dataset:  88%|████████▊ | 7/8 [01:52<00:14, 14.09s/it]

dataset: nq
output: output/tmp/merge-small-nq.tsv


  0%|          | 0/25657 [00:00<?, ?it/s]

dataset: 100%|██████████| 8/8 [01:53<00:00, 14.23s/it]

dataset: scidocs
output: output/tmp/merge-small-scidocs.tsv



